/* Copyright (c) 2009 & onwards. MapR Tech, Inc., All rights reserved */

package com.mapr.streams.producer;

import java.lang.System;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.HashMap;
import java.util.Properties;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.ProducerConfig;
import com.mapr.fs.proto.Marlinserver.MarlinConfigDefaults;

public class ProducerPerformance {

  public static String streamName;
  public static String[] topicNames;
  public static int numTopics = 2;
  public static int numSlowTopics = 0;
  public static int numPartitions = 1;
  public static int numMsgsPerPartition = 100000;
  public static int numBatches = 1;
  public static long batchSleepMs = 10*1000;
  public static boolean multipleFlushers = true;
  public static boolean printProgress = true;
  public static boolean needVerify = true;
  public static boolean roundRobin = false;
  public static boolean hashKey = false;
  public static int totalNumMsgs;
  public static KafkaProducer producer;
  public static int msgValueLength = 200;
  public static long producerPoolSz = 32*1024*1024;
  public static int slowToNormalTopicRatio = 1000;
  public static boolean checkLag = true;
  public static boolean ignoreErr = false;
  public static long metadataRefreshMs = 5*60*1000;

  private static int numMsgsWithTS = 1000;
  private static byte[] inlineValue = null;
  private static byte[] inlineValueDefault = new byte[msgValueLength];

  public static void usage() {
    System.err.println("ProducerPerformance -path <stream-full-name>");
    System.err.println("     [ -ntopics <num topics> (default: 2) ]");
    System.err.println("     [ -npart <numpartitions per topic> (default: 1) ]");
    System.err.println("     [ -nmsgs <num messages per topicfeed> (default: 100000) ]");
    //System.err.println("       -multiflush <enable multiple flushers>");
    //System.err.println("       -nbatches <number of batches>");
    //System.err.println("       -batchsleepms <milliseconds to sleep between batches>");
    System.err.println("     [ -msgsz <msg value size> (default: 200) ]");
    //System.err.println("       -poolsz <pool size>");
    //System.err.println("       -verify <true/false>");
    System.err.println("     [ -rr <round robin true/false> (default: false) ]");
    //System.err.println("       -checklag <true/false>");
    //System.err.println("       -ignoreerror <true/false>");
    System.err.println("     [ -hashkey <true/false> (default: false) ]");
    System.exit(1);
  }

  public static void main(String[] args) throws IOException {
    for (int i = 0; i < args.length; ++i) {
      if (args[i].equals("-path")) {
        i++;
        if (i >= args.length) usage();
        streamName = args[i];
      } else if (args[i].equals("-nmsgs")) {
        i++;
        if (i >= args.length) usage();
        numMsgsPerPartition = Integer.parseInt(args[i]);
      } else if (args[i].equals("-ntopics")) {
        i++;
        if (i >= args.length) usage();
        numTopics = Integer.parseInt(args[i]);
      } else if (args[i].equals("-npart")) {
        i++;
        if (i >= args.length) usage();
        numPartitions = Integer.parseInt(args[i]);
      } else if (args[i].equals("-multiflush")) {
        i++;
        if (i >= args.length) usage();
        multipleFlushers = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-nbatches")) {
        i++;
        if (i >= args.length) usage();
        numBatches = Integer.parseInt(args[i]);
      } else if (args[i].equals("-batchsleepms")) {
        i++;
        if (i >= args.length) usage();
        batchSleepMs = Long.parseLong(args[i]);
      } else if (args[i].equals("-progress")) {
        i++;
        if (i >= args.length) usage();
        printProgress = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-msgsz")) {
        i++;
        if (i >= args.length) usage();
        msgValueLength = Integer.parseInt(args[i]);
      } else if (args[i].equals("-poolsz")) {
        i++;
        if (i >= args.length) usage();
        producerPoolSz = Long.parseLong(args[i]);
      } else if (args[i].equals("-nslowtopics")) {
        i++;
        if (i >= args.length) usage();
        numSlowTopics = Integer.parseInt(args[i]);
      } else if (args[i].equals("-verify")) {
        i++;
        if (i >= args.length) usage();
        needVerify = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-rr")) {
        i++;
        if (i >= args.length) usage();
        roundRobin = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-hashkey")) {
        i++;
        if (i >= args.length) usage();
        hashKey = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-checklag")) {
        i++;
        if (i >= args.length) usage();
        checkLag = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-ignoreerror")) {
        i++;
        if (i >= args.length) usage();
        ignoreErr = Boolean.parseBoolean(args[i]);
      } else {
        usage();
      }
    }

    runTest();
  }

  public static boolean runStressTest(String stream,
                                      int nmsgs,
                                      int ntopics,
                                      int npartitions,
                                      int nbatches,
                                      long sleepms,
                                      boolean verify,
                                      boolean mflushers,
                                      boolean progress,
                                      boolean roundrobin,
                                      boolean hashkey) throws IOException {
    streamName = stream;
    numMsgsPerPartition = nmsgs;
    numTopics = ntopics;
    numSlowTopics = 0;
    needVerify = verify;
    numPartitions = npartitions;
    multipleFlushers = mflushers;
    numBatches = nbatches;
    batchSleepMs = sleepms;
    msgValueLength = 200;
    printProgress = progress;
    roundRobin = roundrobin;
    hashKey = hashkey;
    return runTest();
  }

  public static boolean runBasicTest(String stream,
                                     int nmsgs,
                                     int ntopics,
                                     int slowTopics,
                                     int npartitions,
                                     int nbatches,
                                     long sleepms,
                                     boolean mflushers,
                                     boolean progress,
                                     int msgsz,
                                     long poolsz,
                                     long metadataMillis) throws IOException {
    streamName = stream;
    numMsgsPerPartition = nmsgs;
    numTopics = ntopics;
    numSlowTopics = slowTopics;
    needVerify = true;
    numPartitions = npartitions;
    multipleFlushers = mflushers;
    numBatches = nbatches;
    batchSleepMs = sleepms;
    printProgress = progress;
    msgValueLength = msgsz;
    producerPoolSz = poolsz;
    metadataRefreshMs = metadataMillis;
    return runTest();
  }

  public static boolean runTest() {
    if (streamName == null || streamName.length() == 0) {
      System.err.println("stream name cannot be empty.");
      usage();
    }

    if (numPartitions <= 0) {
      System.err.println("num partitions cannot be negative or zero.");
      usage();
    }

    if (numTopics <= 0) {
      System.err.println("num topics cannot be negative or zero.");
    }

    totalNumMsgs = numMsgsPerPartition * numPartitions * numTopics * numBatches;
    totalNumMsgs += (numMsgsPerPartition / slowToNormalTopicRatio) * numPartitions * numSlowTopics * numBatches;

    // Now get all the topic feeds to send messages to.
    topicNames = new String[numTopics + numSlowTopics];
    int index = 0;
    for(int i = 0; i < numTopics + numSlowTopics; ++i) {
      String topicName = streamName+":topic" + i;
      topicNames[i] = topicName;
    }

    MarlinConfigDefaults cdef = MarlinConfigDefaults.getDefaultInstance();
    Properties props = new Properties();
    props.put("key.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    props.put("value.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    props.put(cdef.getParallelFlushersPerPartition(), new Boolean(multipleFlushers).toString());
    props.put(cdef.getBufferTime(), 3*1000);
    props.put(cdef.getMetadataMaxAge(), metadataRefreshMs);
    props.put(cdef.getBufferMemory(), producerPoolSz);

    long sleepTime = 0;

    byte[] value = null;
    byte[] key = null;
    int keySz = 0;
    if (!needVerify) {
      // send the same data with every msg
      key = new byte[30];
      keySz = key.length;
    }

    producer = new KafkaProducer<byte[], byte[]>(props);
    PerfStats stats = new PerfStats(totalNumMsgs);

    for (int batchIdx = 0; batchIdx < numBatches; ++batchIdx) {
      int msgIdx = batchIdx*numMsgsPerPartition;
      for (int i = 0; i < numMsgsPerPartition; ++i) {
        for (int topicIdx = 0; topicIdx < numTopics + numSlowTopics; ++topicIdx) {

          if ( (topicIdx >= numTopics) && (i % slowToNormalTopicRatio != 0) ) {
            continue;
          }

          for (int partIdx = 0; partIdx < numPartitions; ++partIdx) {
            long curTime = -1;
            if (i % numMsgsWithTS == 0) {
              curTime = System.currentTimeMillis();
            }
            if (needVerify || hashKey) {
              int keyIdx = msgIdx + i;
              key = (topicNames[topicIdx]+":"+partIdx+":"+keyIdx).getBytes();
              keySz = key.length;

              if (needVerify) {
                int bufLen = msgValueLength;
                if (bufLen < (keySz + 10))
                  bufLen = keySz + 10;
                value = new byte[bufLen];
                ByteBuffer valueBuffer = ByteBuffer.wrap(value);
                if (curTime == -1)
                  curTime = System.currentTimeMillis();
                valueBuffer.putLong(curTime);
                valueBuffer.put(key);
                while (true) {
                  if (valueBuffer.position() > (bufLen - 2))
                    break;
                  valueBuffer.putChar('a');
                }
              } else {
                value = GetValueWithLag(curTime);
              }
            } else {
              value = GetValueWithLag(curTime);
            }
            long sendBytes = keySz + msgValueLength;
            Callback cb = new PerfCallback(curTime, sendBytes, stats);

            ProducerRecord<byte[], byte[]> record;
            // round robin or hashing assumes that we are not verifying keys, since the application does
            // not directly control where the messages land.
            if (roundRobin) {
              assert(!needVerify);
              record = new ProducerRecord<byte[], byte[]>(topicNames[topicIdx], null, value);
            } else if (hashKey) {
              assert(!needVerify);
              record = new ProducerRecord<byte[], byte[]>(topicNames[topicIdx], key, value);
            } else {
              // For this test, we will specify the partition in partIdx.
              record = new ProducerRecord<byte[], byte[]>(topicNames[topicIdx], partIdx, key, value);
            }
            producer.send(record, cb);
          }
        }
      }

      if (batchIdx < numBatches - 1) {
        long sleepStart = System.currentTimeMillis();
        try {
          Thread.sleep(batchSleepMs);
        } catch (InterruptedException e) {}
        sleepTime += (System.currentTimeMillis() - sleepStart);
      }
    }

    producer.flush();
    producer.close();
    boolean verify = stats.checkAndVerify();
    stats.printReport(numTopics, numPartitions, sleepTime);

    return verify;
  }

  private static final class PerfStats {
    private long startTime;
    private long endTime;
    private long minLatency;
    private long maxLatency;
    private long totalLatency;
    private long totalBytes;
    private long msgCount;
    private long totalMsgsToExpect;
    private long lastProgressTime;
    private long lastProgressMsgs;
    private int nsecs;

    public PerfStats(int totalMsgs) {
      this.totalMsgsToExpect = totalMsgs;
      this.minLatency = 999999;
      this.maxLatency = 0;
      this.totalBytes = 0;
      this.msgCount = 0;
      this.totalLatency = 0;
      this.endTime = -1;
      this.startTime = System.currentTimeMillis();
      this.lastProgressTime = this.startTime;
      this.lastProgressMsgs = 0;
      this.nsecs = 0;
    }

    public synchronized boolean checkAndVerify() {
      if (this.totalMsgsToExpect != this.msgCount) {
        System.out.println("***** Verification failed! " +
                           this.totalMsgsToExpect + " != " +
                           this.msgCount + " *****");
        return false;
      }
      return true;
    }

    public synchronized void report(long now, long latency, long bytes) {
      this.minLatency = Math.min(this.minLatency, latency);
      this.maxLatency = Math.max(this.maxLatency, latency);
      this.totalBytes += bytes;
      this.msgCount++;
      this.totalLatency += latency;

      if (printProgress && (now - this.lastProgressTime >= 1000)) {
        DateFormat df = new SimpleDateFormat("dd/MM/yy HH:mm:ss");
        Date date = new Date();

        System.out.printf("%s %4d secs %9d msgs %8d msgs/s %n",
                          df.format(date),
                          nsecs,
                          this.msgCount,
                          this.msgCount - this.lastProgressMsgs); 

        this.lastProgressMsgs = this.msgCount;
        this.lastProgressTime = now;
        this.nsecs++;
      }

      if (this.msgCount == this.totalMsgsToExpect) {
        this.endTime = System.currentTimeMillis();
      }
    }

    public synchronized void printReport(int numTopics, int numPartitions, long sleepTime) {
      System.out.println("***** Producer Report Start *****");
      System.out.println("nTopics/nPartitions: " + numTopics + "/" + numPartitions);
      System.out.println("Expected nMsgs: " + this.totalMsgsToExpect);
      System.out.println("Callback nMsgs: " + this.msgCount);
      if (this.endTime == -1)
        this.endTime = System.currentTimeMillis();

      long elapsedTime = this.endTime - this.startTime;
      long runTime = this.endTime - this.startTime - sleepTime;
      System.out.println("Total time (ms): " + elapsedTime);
      System.out.println("Total sleep time (ms): " + sleepTime);
      System.out.println("Total run time (ms): " + runTime);
      System.out.println("Total bytes sent: " + this.totalBytes);
      System.out.println("Min/Max latency (ms): " + this.minLatency + "/" + this.maxLatency);
      System.out.println("Average latency (ms): " + (this.totalLatency/this.msgCount));
      System.out.println("Average nMsgs/sec: " + this.msgCount*1.0/runTime*1000.0);
      System.out.println("Average nKBs/sec: " + this.totalBytes/runTime*1000.0/1024.0);
      System.out.println("***** Producer Report End *****");
    }
  }

  private static final class PerfCallback implements Callback {

    private final long start;
    private final long bytes;
    private static boolean printStack = true;
    private PerfStats stats;

    public PerfCallback(long start, long bytes, PerfStats stats) {
      this.start = start;
      this.bytes = bytes;
      this.stats = stats;
    }

    public void onCompletion(RecordMetadata metadata,
                             Exception exception) {
      long now = System.currentTimeMillis();
      int latency = (int) (now - start);
      stats.report(now, latency, this.bytes);
      if (exception != null) {
        if (printStack) {
          exception.printStackTrace();
          printStack = false;
        }
        if (!ignoreErr)
          System.exit(1);
      } else {
        printStack = true;
      }
    }
  }

  private static byte[] GetValueWithLag(long curTime) {
    if (checkLag && curTime != -1) {
      inlineValue = new byte[msgValueLength];
      ByteBuffer valueBuffer = ByteBuffer.wrap(inlineValue);
      valueBuffer.putLong(curTime);
    } else {
      inlineValue = inlineValueDefault;
    }
    return inlineValue;
  }
}
