/*
 * $Id$
 *
 * This file is part of the Cloud Services Integration Platform (CSIP),
 * a Model-as-a-Service framework, API and application suite.
 *
 * 2012-2022, Olaf David and others, OMSLab, Colorado State University.
 *
 * OMSLab licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */
package csip;

import csip.api.server.ServiceException;
import csip.utils.Client;
import com.mongodb.MongoClient;
import com.mongodb.MongoClientURI;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.UpdateOptions;
import static csip.ModelDataService.CANCELED;
import static csip.ModelDataService.FAILED;
import static csip.ModelDataService.FINISHED;
import static csip.ModelDataService.KEY_METAINFO;
import static csip.ModelDataService.KEY_SUUID;
import static csip.ModelDataService.TIMEDOUT;
import csip.annotations.Author;
import csip.annotations.Description;
import csip.annotations.Documentation;
import csip.annotations.License;
import csip.annotations.Name;
import csip.annotations.State;
import csip.annotations.VersionInfo;
import csip.utils.SimpleCache;
import javax.ws.rs.Path;
import java.net.URISyntaxException;
import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.ws.rs.PathParam;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.errors.WakeupException;
import org.bson.Document;
import org.codehaus.jettison.json.JSONObject;

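/**
 * Queueing front end for CSIP model services.
 *
 * Requests are held in a bounded queue and forwarded to the delegate
 * backend once its current load is below the configured capacity;
 * completion events arrive via Kafka and results are pushed to the
 * caller's webhook.
 *
 * A minimal request payload might look like this (a sketch with a
 * hypothetical webhook url; the 'webhook' metainfo entry is required
 * while csip.pubsub.webhook.payload is true):
 *
 * <pre>{@code
 * {
 *   "metainfo": { "webhook": "https://client.example.org/hook" },
 *   "parameter": [ ... ]
 * }
 * }</pre>
 */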
@Name("pubsub")
@Description("Publish/Subscribe")
@Documentation("https://alm.engr.colostate.edu/csip")
@State(State.STABLE)
@License(License.MIT)
@VersionInfo("$Id$")
@Author(name = "od", email = "<odavid@colostate.edu>", org = "CSU")
@Path("p/pubsub/{delegate:.*}")
public class QueueingModelDataService extends ModelDataService {

  static final String KEY_QUEUE_POS = "queue_pos";

  static String delegateUrl = Config.getString("csip.pubsub.delegate.url");
  static boolean needsWebHook = Config.getBoolean("csip.pubsub.webhook.payload", true);
  static int queueLen = Config.getInt("csip.pubsub.queue.len", Integer.MAX_VALUE);
  static int queueRemainingLen = Config.getInt("csip.pubsub.queue.remaining.len", 25);
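  // Hypothetical deployment settings (names from above, values illustrative):
  //   csip.pubsub.delegate.url        = https   (scheme prefix for the delegate service, see doProcess)
  //   csip.pubsub.queue.len           = 10000   (bound of the request queue)
  //   csip.pubsub.queue.remaining.len = 25      (reject submissions below this headroom)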

  static Logger l = Config.LOG;

  private static QueueManagement mgmt;


  static {
    try {
      mgmt = new QueueManagement();
    } catch (Exception E) {
      l.log(Level.SEVERE, "Init error:", E);
    }
  }

  @PathParam("delegate")
  String delegate;

  boolean isQueued = false;


  @Override
  boolean isQueued() {
    return isQueued;
  }


  @Override
  protected void doProcess() throws Exception {
    if (delegate == null || delegate.isEmpty())
      throw new ServiceException("No delegate service provided.");

    // Debugging/management endpoints: status, reset, payloads, toggle
    if (delegate.equals("status")) {
      results().put("req_in_queue", mgmt.getQueueLen(), "number of requests currently in queue.")
          .put("queue_remaining", mgmt.getRemainingCapacity(), "remaining queue capacity.")
          .put("queue_capacity", queueLen, "total queue capacity.")
          //          .put("queue_min_remaining", queueRemainingLen)
          .put("res_in_publish_queue", mgmt.deliveryQueue.size(), "number of responses in publish queue.")
          .put("incoming", mgmt.incoming.get(), "total number of externally received requests.")
          .put("queued_sub", mgmt.queued_sub.get(), "total number of queued requests.")
          //          .put("queued_rec", mgmt.queued_rec.get())
          .put("queued_back", mgmt.queued_back.get(), "total number of requests put back because of backend capacity at max.")
          .put("exec_sub", mgmt.exec_sub.get(), "total number of requests submitted for backend execution.")
          .put("exec_rec", mgmt.exec_rec.get(), "total number of responses received from backend.")
          //          .put("webhook_sub", mgmt.webhook_sub.get())
          .put("webhook_sub_failed", mgmt.webhook_sub_failed.get(), "total number of responses failed for webhook submission.")
          .put("webhook_rec", mgmt.webhook_rec.get(), "total number of responses submitted with webhook.")
          .put("openQueue", mgmt.queueOpen.get(), "is the queue open or not.");
      return;
    }
    if (delegate.equals("reset")) {
      synchronized (mgmt) {
        mgmt.queue.clear();
        mgmt.sn.set(0);
      }
      results().put("ok", true);
      return;
    }
    if (delegate.equals("payloads")) {
      int i = 0;
      for (QueueManagement.Payload payload : mgmt.queue) {
        results().put(i++ + ": " + payload.url, payload.request);
      }
      results().put("ok", true);
      return;
    }
    if (delegate.equals("toggle")) {
      mgmt.queueOpen.set(!mgmt.queueOpen.get());
      results().put("queueOpen", mgmt.queueOpen.get());
      return;
    }

    if (l.isLoggable(Level.INFO))
      l.log(Level.INFO, "Delegate: {0}", delegate);

    if (!mgmt.queueOpen.get())
      throw new ServiceException("Queue closed for submission, try again later.");

    if (mgmt.getRemainingCapacity() < queueRemainingLen)
      throw new ServiceException("Queue capacity reached, try again later.");

    if (needsWebHook && !metainfo().hasName(KEY_WEBHOOK))
      throw new ServiceException("'webhook' metainfo missing.");

    if (delegateUrl == null) {
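      // no delegate scheme configured: fall back to the scheme of the incoming request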
      String u = request().getURL();
      delegateUrl = u.substring(0, u.indexOf(":"));
    }

    String delegateService = delegateUrl + ":" + delegate;

    JSONObject v = new JSONObject(request().getRequest().toString());

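    // mark the forwarded request as ASYNC and strip request-specific metainfo before delegating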
    JSONObject mi = v.getJSONObject(KEY_METAINFO);
    mi.put(KEY_MODE, ASYNC);
    mi.remove("cloud_node");
    mi.remove("status");
    mi.remove("tstamp");
    mi.remove("request_ip");
    mi.put("csip-auth", request().getAuthToken());

    try {
      if (mgmt.checkTarget) {
        // check if target service is available
        long p = Client.ping(delegateService, mgmt.pingTimeout);
        if (p == -1)
          throw new ServiceException("Target service not available: " + delegateService);
      }
      mgmt.incoming.incrementAndGet();
      long pos = mgmt.queue(delegateService, v.toString());
      if (pos == -1)
        throw new ServiceException("Error queueing the service, try again later.");

      metainfo().put(KEY_QUEUE_POS, pos);
      if (l.isLoggable(Level.INFO))
        l.log(Level.INFO, "QUEUE POS, " + pos);

      isQueued = true;
    } catch (ServiceException E) {
      throw E;
    } catch (Exception E) {
      throw new ServiceException("Error queueing the service", E);
    }
  }

  /**
   * QueueManagement.
   *
   * Holds the bounded request queue and runs three workers: a submitter that
   * forwards queued requests to the backend, a receiver that picks up
   * completion events from Kafka, and a publisher that delivers results to
   * webhooks.
   */
  private static class QueueManagement {

    final String STRING_DESER = "org.apache.kafka.common.serialization.StringDeserializer";

    String bootstrap_servers = Config.getString("csip.pubsub.kafka.bootstrap_servers");
    long consumerPoll = Config.getLong("csip.pubsub.kafka.consumer.poll.ms", 10000);
    String consumerGroupId = Config.getString("csip.pubsub.kafka.consumer.group.id", "test-consumer-group");
    String resultTopic = Config.getString("csip.pubsub.result.topic", "8086");
    long submitDelay = Config.getLong("csip.pubsub.submit.delay.ms", 1000);
    int defaultCapacity = Config.getInt("csip.pubsub.default.capacity", 8);
    int pingTimeout = Config.getInt("csip.pubsub.ping.timeout.ms", 1000);
    long delayAtCapacity = Config.getLong("csip.pubsub.atcapacity.delay.ms", 2000);
    boolean checkTarget = Config.getBoolean("csip.pubsub.check.target", false);
    long offerMS = Config.getLong("csip.pubsub.offer.ms", 500);
    long pollMS = Config.getLong("csip.pubsub.poll.ms", 2000);
    long loadcheck = Config.getLong("csip.pubsub.loadcheck.ms", 2000);

    String connect = Config.getString("csip.pubsub.stats", null);

    Consumer<String, String> receiveConsumer = getResultConsumer();

    ExecutorService executor = Executors.newCachedThreadPool();
    ScheduledExecutorService ses = Executors.newSingleThreadScheduledExecutor();

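    /**
     * LoadProbe.
     *
     * Periodically queries the /q/running endpoint of each tracked backend
     * context and caches the current load; a probe expires after MAX_TTL
     * update cycles without use.
     */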
    class LoadProbe implements Runnable {

      Client cl = new Client(l);
      int MAX_TTL = 10;

      private class Load {

        Integer load;
        AtomicInteger ttl = new AtomicInteger(MAX_TTL);
      }

      // service context -> current load
      Map<String, Load> loads = new ConcurrentHashMap<>();
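      // service url -> context cache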
      Map<String, String> sh = new HashMap<>();


      @Override
      public void run() {
        try {
          if (threadsRunning.get())
            update();
        } catch (Exception ex) {
          l.log(Level.SEVERE, null, ex);
        }
      }


      void close() {
        cl.close();
      }


      private void update() throws Exception {
        if (l.isLoggable(Level.INFO))
          l.info("Backend update.");

        for (Map.Entry<String, Load> entry : loads.entrySet()) {
          String context = entry.getKey();
          int ttl = entry.getValue().ttl.decrementAndGet();
          if (ttl <= 0) {
            loads.remove(context);
            if (l.isLoggable(Level.INFO))
              l.info("removed probe for : " + context);
          } else {
            Integer i = query(context);
            entry.getValue().load = i;
            if (l.isLoggable(Level.INFO))
              l.info("probe: " + context + " -> " + i + ", " + ttl);
          }
        }
      }


      private synchronized Integer query(String s) throws Exception {
        try {
          String result = cl.doGET(s + "/q/running");
          return Integer.valueOf(result);
        } catch (Exception E) {
          l.log(Level.SEVERE, "Error getting the current running services: ", E);
          return Integer.MAX_VALUE;
        }
      }


      int getCurrentLoad(String service) throws Exception {
        String context = getContext(service);
        Load load = loads.get(context);
        if (load == null) {
          load = new Load();
          load.load = query(context);
          loads.put(context, load);
        }
        // refresh the TTL since this probe was just used
        load.ttl.set(MAX_TTL);
        return load.load;
      }


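      // Reduce a service URL to its deployment context, e.g. (assuming
      // Utils.getURIParts yields scheme, host, port and webapp context)
      // "https://example.org:8080/csip-model/m/foo/1.0" -> "https://example.org:8080/csip-model"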
      String getContext(String service) throws URISyntaxException {
        String context = sh.get(service);
        if (context == null) {
          String[] u = Utils.getURIParts(service);
          sh.put(service, context = u[0] + u[1] + u[2] + "/" + u[3]);
        }
        return context;
      }
    }

    LoadProbe probe = new LoadProbe();

    FutureTask<String> submitTask = new FutureTask<>(new SubmitJobThread());
    FutureTask<String> receiveTask = new FutureTask<>(new ReceiveJobStatusThread());
    FutureTask<String> deliveryTask = new FutureTask<>(new PublishJobThread());

    final AtomicBoolean threadsRunning = new AtomicBoolean(true);

    SimpleCache<String, Integer> capacities = new SimpleCache<>();

    //stats
    AtomicInteger incoming = new AtomicInteger(0);
    AtomicInteger queued_sub = new AtomicInteger(0);
    AtomicInteger queued_rec = new AtomicInteger(0);
    AtomicInteger queued_back = new AtomicInteger(0);
    AtomicInteger exec_sub = new AtomicInteger(0);
    AtomicInteger exec_rec = new AtomicInteger(0);
    AtomicInteger webhook_sub = new AtomicInteger(0);
    AtomicInteger webhook_sub_failed = new AtomicInteger(0);
    AtomicInteger webhook_rec = new AtomicInteger(0);
    AtomicInteger sn = new AtomicInteger(0);

    BlockingQueue<Payload> queue = new LinkedBlockingQueue<>(queueLen);
    BlockingQueue<String> deliveryQueue = new LinkedBlockingQueue<>();

    final AtomicBoolean queueOpen = new AtomicBoolean(true);

    Stats stats;

    static class Payload {

      String url;
      String request;


      Payload(String url, String request) {
        this.url = url;
        this.request = request;
      }
    }


    public int getQueueLen() {
      return queue.size();
    }


    public int getRemainingCapacity() {
      return queue.remainingCapacity();
    }


    Properties getConsumerProperties() {
      Properties p = new Properties();
      p.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrap_servers);
      p.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false);
      p.put(ConsumerConfig.GROUP_ID_CONFIG, consumerGroupId);
      p.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, STRING_DESER);
      p.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, STRING_DESER);
      return p;
    }


    Consumer<String, String> getResultConsumer() {
      Consumer<String, String> c = new KafkaConsumer<>(getConsumerProperties());
      c.subscribe(Arrays.asList(resultTopic));
      return c;
    }


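    /**
     * Offer a payload to the bounded queue.
     *
     * @return the queue length after the offer (reported to the caller as its
     *         queue position), or -1 if the queue did not accept the payload
     *         within the offer timeout.
     */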
    synchronized int queue(String url, String request) throws Exception {
      if (l.isLoggable(Level.INFO))
        l.log(Level.INFO, "Queueing  :{0} {1}", new Object[]{url, request});

      boolean s = queue.offer(new Payload(url, request), offerMS, TimeUnit.MILLISECONDS);
      if (!s)
        return -1;

      queued_sub.getAndIncrement();
      return getQueueLen();
    }


    /**
     * Static per-context capacity, based on property settings.
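     *
     * The property key is derived from the context ('/' and ':' map to '.');
     * e.g. for a hypothetical context "https://example.org:8080/csip-model" the
     * lookup key is "csip.pubsub.https...example.org.8080.csip-model.capacity".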
     */
    int getContextCapacity(String context) {
      return capacities.get(context, c
          -> Config.getInt("csip.pubsub."
              + c.replace('/', '.').replace(':', '.') + ".capacity", defaultCapacity));
    }

    /**
     * Submit for execution.
     *
     * This thread pulls entries from the queue and submits them to the
     * backend for execution; if the backend is at capacity, the entry is put
     * back at the tail of the queue and the submit delay is increased.
     */
    class SubmitJobThread implements Callable<String> {

      Client cl = new Client(l);
      long delay = submitDelay;


      private void executeAsync(String url, String payload, int capacity, Client c) {
        try {
          JSONObject o = new JSONObject(payload);
          Map<String, String> header = new HashMap<>();
          header.put(KEY_SUUID, o.getJSONObject(KEY_METAINFO).getString(KEY_SUUID));
          o.getJSONObject("metainfo").put("cap", capacity);
          o.getJSONObject("metainfo").put("sn", sn.get());
          sn.incrementAndGet();

          JSONObject result = c.doPOST(url, o, header);
          if (l.isLoggable(Level.FINE))
            l.log(Level.FINE, "POST Run to " + url + " ... received: " + result.toString());

          exec_sub.incrementAndGet();
        } catch (Exception ex) {
          l.log(Level.SEVERE, null, ex);
        }
      }


      private void submit(Client cl, String serviceUrl, String servicePayload) throws Exception {
        // wait a bit before continuing
        try {
          Thread.sleep(delay);
        } catch (InterruptedException ex) {
          // restore the interrupt flag so the worker loop can observe it
          Thread.currentThread().interrupt();
          l.log(Level.INFO, "Interrupted");
        }
        if (checkTarget) {
          // Ping the service first.
          long p = Client.ping(serviceUrl, pingTimeout);
          if (p == -1) {
            // target unreachable: put the payload back; log if it got dropped
            if (queue(serviceUrl, servicePayload) == -1)
              l.log(Level.SEVERE, "Dropped payload, queue full: {0}", serviceUrl);
            delay = delayAtCapacity;
            l.log(Level.INFO, "Cannot ping the service, back in line...");
            return;
          }
        }

        // check the current load in the backend.
        int currentLoad = probe.getCurrentLoad(serviceUrl);
        int contextCapacity = getContextCapacity(probe.getContext(serviceUrl));
        l.log(Level.INFO, "Load for {2}: {0}/{1}",
            new Object[]{currentLoad, contextCapacity, serviceUrl});

        // compare the current backend load against the backend capacity
        if (currentLoad >= contextCapacity) {
          // capacity reached, put it back in line; log if it got dropped
          if (queue(serviceUrl, servicePayload) == -1)
            l.log(Level.SEVERE, "Dropped payload, queue full: {0}", serviceUrl);
          queued_back.incrementAndGet();
          delay = delayAtCapacity;
          l.log(Level.WARNING, "back in line...{0}, {1}/{2}",
              new Object[]{serviceUrl, currentLoad, contextCapacity});
        } else {
          // capacity is fine, submit for execution.
          queued_rec.incrementAndGet();
          executeAsync(serviceUrl, servicePayload, currentLoad, cl);
          delay = submitDelay;
        }
      }


      @Override
      public String call() throws Exception {
        try {
          while (threadsRunning.get()) {
            Payload payload = queue.poll(pollMS, TimeUnit.MILLISECONDS);
            if (payload != null) {
              l.log(Level.INFO, "RECEIVED: {0} ", new Object[]{payload.url});
              l.log(Level.FINE, "  Request: {0}", new Object[]{payload.request});
              submit(cl, payload.url, payload.request);
            }
            l.log(Level.INFO, "Submit Alive.");
          }
        } finally {
          cl.close();
          l.log(Level.INFO, "Submitter closed.");
        }
        return "Done Submit.";
      }
    }

    /**
     * Publish results.
     *
     * This thread pulls completed results from the delivery queue and posts
     * them to the webhook named in each result's metainfo.
     */
    class PublishJobThread implements Callable<String> {

      Client cl = new Client(l);

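      // POST the full result document to the webhook url from its metainfo;
      // delivery counts as acknowledged once the hook returns a response body.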

      private void publish(Client c, String result) throws Exception {

        // call webhook
        JSONObject o = new JSONObject(result);
        if (o.has(KEY_METAINFO)) {
          String url = o.getJSONObject(KEY_METAINFO).optString(KEY_WEBHOOK);
          if (!url.isEmpty()) {
            l.log(Level.INFO, "Webhook Post to " + url);
            webhook_sub.incrementAndGet();
            String ack = c.doPOST(url, result);
            if (ack != null) {
              l.log(Level.INFO, "Delivered and Acknowledged: " + ack);
              webhook_rec.incrementAndGet();
            }
            if (mgmt.stats != null) {
              l.log(Level.INFO, "Calling statistics: ");
              String auth = o.getJSONObject(KEY_METAINFO).optString("csip-auth");
              String service = o.getJSONObject(KEY_METAINFO).optString("service_url");
              long cpu_time = o.getJSONObject(KEY_METAINFO).optLong("cpu_time");
              mgmt.stats.inc(l, service, cpu_time, auth);
            }
          } else {
            l.log(Level.WARNING, "No webhook found: '" + url + "'");
          }

        } else {
          webhook_sub_failed.incrementAndGet();
          l.log(Level.SEVERE, "PublishError for :" + result);
        }
      }


      @Override
      public String call() throws Exception {
        try {
          while (threadsRunning.get()) {
            String payload = deliveryQueue.poll(pollMS, TimeUnit.MILLISECONDS);
            if (payload != null) {
              l.log(Level.INFO, "RECEIVED FOR PUBLISH: {0} ", new Object[]{payload});
              publish(cl, payload);
            }
            l.log(Level.INFO, "Publish Alive.");
          }
        } finally {
          cl.close();
          l.log(Level.INFO, "Publisher closed.");
        }
        return "Done Publish.";
      }
    }

    /**
     * ReceiveJobStatusThread. This thread receives result status messages
     * for the executed services from Kafka and, for terminal states
     * (FINISHED, FAILED, CANCELED, TIMEDOUT), fetches the full result for
     * webhook delivery.
     *
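     * A record is expected to carry the final status as its key and
     * "<suuid> <service-url>" (whitespace separated) as its value, e.g.
     * (hypothetical ids) key "FINISHED", value
     * "0123-abcd https://example.org:8080/csip-model/m/foo/1.0".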
     */
    class ReceiveJobStatusThread implements Callable<String> {

      Duration d = Duration.ofMillis(consumerPoll);
      Client cl = new Client(l);


      private void queryResults(String status, String suuid_url, Client c) {
        if (status.equals(FINISHED)
            || status.equals(FAILED)
            || status.equals(CANCELED)
            || status.equals(TIMEDOUT)) {
          String[] v = suuid_url.split("\\s+");
          exec_rec.incrementAndGet();
          try {
            // call Q service
            String[] u = Utils.getURIParts(v[1]);
            String url1 = u[0] + u[1] + u[2] + "/" + u[3] + "/q/" + v[0];
            l.log(Level.INFO, "Query Results " + url1);
            String result = c.doGET(url1);
            if (l.isLoggable(Level.FINE))
              l.log(Level.FINE, "Received RESULT for:  " + url1 + " " + result);

            deliveryQueue.put(result);
          } catch (Exception ex) {
            l.log(Level.SEVERE, null, ex);
          }
        }
      }


      @Override
      public String call() throws Exception {
        try {
          while (threadsRunning.get()) {
            ConsumerRecords<String, String> records = receiveConsumer.poll(d);
            if (records.count() > 0) {
              records.forEach(record -> {
                l.log(Level.INFO, "{0} RECEIVED: {1} {2}",
                    new Object[]{record.offset(), record.key(), record.value()});
                queryResults(record.key(), record.value(), cl);
              });
              receiveConsumer.commitSync();
            }
            l.log(Level.INFO, "Receive Alive.");
          }
        } catch (WakeupException E) {
          if (threadsRunning.get())
            throw E;
        } finally {
          receiveConsumer.close();
          cl.close();
          l.log(Level.INFO, "Receiver closed.");
        }
        return "Done Receive.";
      }
    }

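    /**
     * Stats.
     *
     * Usage counters in MongoDB, one collection per auth token: each service
     * document accumulates a call "count" and the total cpu "time".
     */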
    static class Stats {

      MongoClient mongo;
      MongoDatabase db;
      UpdateOptions opt = new UpdateOptions().upsert(true);

      static final Document INC = new Document("$inc", new Document("count", 1L));


      Stats(String uri) {
        MongoClientURI u = new MongoClientURI(uri);
        String dbname = u.getDatabase();
        if (dbname == null)
          dbname = "pubsub";

        mongo = new MongoClient(u);
        db = mongo.getDatabase(dbname);
      }


      void inc(Logger l, String serviceUrl, long time, String collection) {
        if (collection.isEmpty()) {
          l.warning("No auth/collection for " + serviceUrl);
          return;
        }
        MongoCollection<Document> c = db.getCollection(collection);
        // upsert: create the service document if needed, then bump the counters
        c.updateOne(Filters.eq("service", serviceUrl), INC, opt);
        c.updateOne(Filters.eq("service", serviceUrl),
            new Document("$inc", new Document("time", time)), opt);
      }


      void close() {
        mongo.close();
      }
    }


    void shutdown() {
      threadsRunning.set(false);
      receiveConsumer.wakeup();
      try {
        l.log(Level.INFO, receiveTask.get());
        l.log(Level.INFO, submitTask.get());
        l.log(Level.INFO, deliveryTask.get());
      } catch (InterruptedException | ExecutionException ex) {
        l.log(Level.SEVERE, null, ex);
      }
      executor.shutdown();
      if (stats != null)
        stats.close();

      ses.shutdown();
      probe.close();
    }


    void startup() {
      executor.submit(submitTask);
      executor.submit(receiveTask);
      executor.submit(deliveryTask);
      ses.scheduleWithFixedDelay(probe, 2000, loadcheck, TimeUnit.MILLISECONDS);
      if (connect != null)
        stats = new Stats(connect);
    }
  }


  public static void onContextInit() {
    try {
      mgmt.startup();
      l.log(Level.INFO, "Started Pub/Sub Threads.");
    } catch (Exception E) {
      l.log(Level.SEVERE, null, E);
    }
  }


  public static void onContextDestroy() {
    try {
      mgmt.shutdown();
      l.log(Level.INFO, "Stopped Pub/Sub Threads.");
    } catch (Exception E) {
      l.log(Level.SEVERE, null, E);
    }
  }
}