/* Copyright (c) 2009 & onwards. MapR Tech, Inc., All rights reserved */

package com.mapr.streams.producer;

import java.lang.System;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.HashMap;
import java.util.Properties;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.ProducerConfig;
import com.mapr.fs.proto.Marlinserver.MarlinConfigDefaults;

public class MarlinProducerMultiThreadPerformance {

  public static String streamName;
  public static int numTopics = 2;
  public static int numPartitions = 4;
  public static int numMsgsPerPartition = 100000;
  public static int numBatches = 1;
  public static long batchSleepMs = 10*1000;
  public static boolean multipleFlushers = true;
  public static boolean printProgress = true;
  public static int totalNumMsgs;
  public static KafkaProducer producer;
  public static final int MSG_VALUE_LENGTH = 200;

  public static void usage() {
    System.err.println("MarlinProducerMultiThreadPerformance -path <stream-full-name> [-ntopics <num topics>");
    System.err.println(" -npart <numpartitions per topic> -nmsgs <num messages per topicfeed> -multiflush <enable multiple flushers>");
    System.err.println(" -nbatches <number of batches> -batchsleepms <milliseconds to sleep between batches> ]");
    System.exit(1);
  }

  public static void main(String[] args) throws IOException {
    for (int i = 0; i < args.length; ++i) {
      if (args[i].equals("-path")) {
        i++;
        if (i >= args.length) usage();
        streamName = args[i];
      } else if (args[i].equals("-nmsgs")) {
        i++;
        if (i >= args.length) usage();
        numMsgsPerPartition = Integer.parseInt(args[i]);
      } else if (args[i].equals("-ntopics")) {
        i++;
        if (i >= args.length) usage();
        numTopics = Integer.parseInt(args[i]);
      } else if (args[i].equals("-npart")) {
        i++;
        if (i >= args.length) usage();
        numPartitions = Integer.parseInt(args[i]);
      } else if (args[i].equals("-multiflush")) {
        i++;
        if (i >= args.length) usage();
        multipleFlushers = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-nbatches")) {
        i++;
        if (i >= args.length) usage();
        numBatches = Integer.parseInt(args[i]);
      } else if (args[i].equals("-batchsleepms")) {
        i++;
        if (i >= args.length) usage();
        batchSleepMs = Long.parseLong(args[i]);
      } else if (args[i].equals("-progress")) {
        i++;
        if (i >= args.length) usage();
        printProgress = Boolean.parseBoolean(args[i]);
      } else {
        usage();
      }
    }

    runTest();
  }

  public static boolean runStressTest(String stream,
                                      int nmsgs,
                                      int ntopics,
                                      int npartitions,
                                      int nbatches,
                                      long sleepms,
                                      boolean mflushers,
                                      boolean progress) throws IOException {
    streamName = stream;
    numMsgsPerPartition = nmsgs;
    numTopics = ntopics;
    numPartitions = npartitions;
    multipleFlushers = mflushers;
    numBatches = nbatches;
    batchSleepMs = sleepms;
    printProgress = progress;
    return runTest();
  }

  public static boolean runTest() {
    if (streamName == null || streamName.length() == 0) {
      System.err.println("stream name cannot be empty.");
      usage();
    }

    if (numPartitions <= 0) {
      System.err.println("num partitions cannot be negative or zero.");
      usage();
    }

    if (numTopics <= 0) {
      System.err.println("num topics cannot be negative or zero.");
    }

    totalNumMsgs = numMsgsPerPartition * numPartitions * numTopics * numBatches;

    MarlinConfigDefaults cdef = MarlinConfigDefaults.getDefaultInstance();
    Properties props = new Properties();
    props.put("key.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    props.put("value.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    //props.put("buffer.memory", 16 * 1024 * 1024);
    props.put(cdef.getParallelFlushersPerPartition(), multipleFlushers);
    props.put(cdef.getBufferTime(), 3*1000);
    props.put(cdef.getMetadataMaxAge(), 5*60*1000);

    long sleepTime = 0;

    producer = new KafkaProducer<byte[], byte[]>(props);
    PerfStats stats = new PerfStats(totalNumMsgs);


    // Now get all the topic feeds to send messages to.
    Thread[] partitionSenders = new Thread[numTopics*numPartitions];
    for(int i = 0; i < numTopics; ++i) {
      String topicName = streamName+":topic" + i;
      for (int j = 0; j < numPartitions; ++j) {
        PartitionSender currentPartSender = new PartitionSender(producer, stats, topicName, j,
                                                                numBatches, numMsgsPerPartition, batchSleepMs);
        partitionSenders[i*numPartitions + j] = new Thread(currentPartSender);
        partitionSenders[i*numPartitions + j].start();
      }
    }

    try {
      for(int i = 0; i < partitionSenders.length; i++) {
        partitionSenders[i].join();
      }
    } catch ( InterruptedException e) {
      System.out.println("***** Joining partitionSenders failed *****");
    }

    producer.close();
    boolean verify = stats.checkAndVerify();
    stats.printReport(numTopics, numPartitions, sleepTime);

    return verify;
  }

  private static final class PartitionSender implements Runnable {
    private String topicName;
    private int partitionId;
    private String keyPrefix;
    private KafkaProducer producer;
    private PerfStats stats;
    private int numBatches;
    private int numMsgsPerBatch;
    private long sleepTime;

    public PartitionSender(KafkaProducer prod, PerfStats ss, String tn, int pid, int nb, int nmspb, long st) {
      this.topicName = tn;
      this.partitionId = pid;
      this.producer = prod;
      this.stats = ss;
      this.numBatches = nb;
      this.numMsgsPerBatch = nmspb;
      this.sleepTime = st;

      keyPrefix = topicName+":"+partitionId;
    }

    public void run() {
      for (int batchIdx = 0; batchIdx < numBatches; ++batchIdx) {

      int msgIdx = batchIdx*numMsgsPerBatch;

        for (int i = 0; i < numMsgsPerPartition; ++i) {
          int keyIdx = msgIdx + i;
          byte[] key = (keyPrefix+":"+keyIdx).getBytes();
          int keySz = key.length;

          byte[] value = new byte[MSG_VALUE_LENGTH];
          ByteBuffer valueBuffer = ByteBuffer.wrap(value);
          long curTime = System.currentTimeMillis();
          valueBuffer.putLong(curTime);
          valueBuffer.put(key);
          while (true) {
            if (valueBuffer.position() > (MSG_VALUE_LENGTH - 2))
              break;
            valueBuffer.putChar('a');
          }

          long sendStart = System.currentTimeMillis();
          long sendBytes = keySz + MSG_VALUE_LENGTH;
          Callback cb = new PerfCallback(sendStart, sendBytes, stats);

          // For this test, we will specify the partition in partIdx.
          ProducerRecord<byte[], byte[]> record =
            new ProducerRecord<byte[], byte[]>(topicName, partitionId, key, value);
          producer.send(record, cb);
        }

        if (batchIdx < numBatches - 1) {
          long sleepStart = System.currentTimeMillis();
          try {
            Thread.sleep(batchSleepMs);
          } catch (InterruptedException e) {}
          sleepTime += (System.currentTimeMillis() - sleepStart);
        }
      }

      producer.flush();
    }
  }

  private static final class PerfStats {
    private long startTime;
    private long endTime;
    private long minLatency;
    private long maxLatency;
    private long totalLatency;
    private long totalBytes;
    private long msgCount;
    private long totalMsgsToExpect;
    private long lastProgressTime;
    private long lastProgressMsgs;
    private int nsecs;

    public PerfStats(int totalMsgs) {
      this.totalMsgsToExpect = totalMsgs;
      this.minLatency = 999999;
      this.maxLatency = 0;
      this.totalBytes = 0;
      this.msgCount = 0;
      this.totalLatency = 0;
      this.endTime = -1;
      this.startTime = System.currentTimeMillis();
      this.lastProgressTime = this.startTime;
      this.lastProgressMsgs = 0;
      this.nsecs = 0;
    }

    public synchronized boolean checkAndVerify() {
      if (this.totalMsgsToExpect != this.msgCount) {
        System.out.println("***** Verification failed! " +
                           this.totalMsgsToExpect + " != " +
                           this.msgCount + " *****");
        return false;
      }
      return true;
    }

    public synchronized void report(long now, long latency, long bytes) {
      this.minLatency = Math.min(this.minLatency, latency);
      this.maxLatency = Math.max(this.maxLatency, latency);
      this.totalBytes += bytes;
      this.msgCount++;
      this.totalLatency += latency;

      if (printProgress && (now - this.lastProgressTime >= 1000)) {
        DateFormat df = new SimpleDateFormat("dd/MM/yy HH:mm:ss");
        Date date = new Date();

        System.out.printf("%s %4d secs %9d msgs %8d msgs/s %n",
                          df.format(date),
                          nsecs,
                          this.msgCount,
                          this.msgCount - this.lastProgressMsgs); 

        this.lastProgressMsgs = this.msgCount;
        this.lastProgressTime = now;
        this.nsecs++;
      }

      if (this.msgCount == this.totalMsgsToExpect)
        this.endTime = System.currentTimeMillis();
    }

    public synchronized void printReport(int numTopics, int numPartitions, long sleepTime) {
      System.out.println("***** Producer Report Start *****");
      System.out.println("nTopics/nPartitions: " + numTopics + "/" + numPartitions);
      System.out.println("Expected nMsgs: " + this.totalMsgsToExpect);
      System.out.println("Callback nMsgs: " + this.msgCount);
      if (this.endTime == -1)
        this.endTime = System.currentTimeMillis();

      long elapsedTime = this.endTime - this.startTime;
      long runTime = this.endTime - this.startTime - sleepTime;
      System.out.println("Total time (ms): " + elapsedTime);
      System.out.println("Total sleep time (ms): " + sleepTime);
      System.out.println("Total run time (ms): " + runTime);
      System.out.println("Total bytes sent: " + this.totalBytes);
      System.out.println("Min/Max latency (ms): " + this.minLatency + "/" + this.maxLatency);
      System.out.println("Average latency (ms): " + (this.totalLatency/this.msgCount));
      System.out.println("Average nMsgs/sec: " + this.msgCount*1.0/runTime*1000.0);
      System.out.println("Average nKBs/sec: " + this.totalBytes/runTime*1000.0/1024.0);
      System.out.println("***** Producer Report End *****");
    }
  }

  private static final class PerfCallback implements Callback {

    private final long start;
    private final long bytes;
    private PerfStats stats;

    public PerfCallback(long start, long bytes, PerfStats stats) {
      this.start = start;
      this.bytes = bytes;
      this.stats = stats;
    }

    public void onCompletion(RecordMetadata metadata,
                             Exception exception) {
      long now = System.currentTimeMillis();
      int latency = (int) (now - start);
      stats.report(now, latency, this.bytes);
      if (exception != null) {
        exception.printStackTrace();
        System.exit(1);
      }
    }
  }
}
