/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */

package com.mapr.streams.producer;

import java.lang.System;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Properties;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.ProducerRecord;
import com.mapr.fs.proto.Marlinserver.MarlinConfigDefaults;

public class Producer {

  public static String streamName;
  public static int numStreams = 2;
  public static int numTopics = 2;
  public static int numSlowTopics = 2;
  public static int numPartitions = 4;
  public static int numMsgsPerPartition = 100000;
  public static boolean multipleFlushers = true;
  public static KafkaProducer producer;
  public static final int MSG_VALUE_LENGTH = 200;
  public static boolean printStats = true;

  public static void usage() {
    System.err.println("Producer -path <stream-intialname> [-nstreams <num streams> -ntopics <num topics> -nslowtopics <num topics> -npart <numpartitions per topic> -nmsgs <num messages per topicfeed>]  -multiflush <enable multiple flushers>]");
    System.exit(1);
  }

  public static void main(String[] args) throws IOException {
    for (int i = 0; i < args.length; ++i) {
      if (args[i].equals("-path")) {
        i++;
        if (i >= args.length) usage();
        streamName = args[i];
      } else if (args[i].equals("-nstreams")) {
        i++;
        if (i >= args.length) usage();
        numStreams = Integer.parseInt(args[i]);
      } else if (args[i].equals("-nmsgs")) {
        i++;
        if (i >= args.length) usage();
        numMsgsPerPartition = Integer.parseInt(args[i]);
      } else if (args[i].equals("-ntopics")) {
        i++;
        if (i >= args.length) usage();
        numTopics = Integer.parseInt(args[i]);
      } else if (args[i].equals("-nslowtopics")) {
        i++;
        if (i >= args.length) usage();
        numSlowTopics = Integer.parseInt(args[i]);
      } else if (args[i].equals("-npart")) {
        i++;
        if (i >= args.length) usage();
        numPartitions = Integer.parseInt(args[i]);
      } else if (args[i].equals("-multiflush")) {
        i++;
        if (i >= args.length) usage();
        multipleFlushers = Boolean.parseBoolean(args[i]);
      } else {
        usage();
      }
    }

    run();
  }

  public static boolean runTest(String path, int nstreams, int ntopics, int nslowtopics, int npart, int nmsgs, boolean mflushers) {
    streamName = path;
    numStreams = nstreams;
    numMsgsPerPartition = nmsgs;
    numTopics = ntopics;
    numSlowTopics = nslowtopics;
    numPartitions = npart;
    printStats = false;
    multipleFlushers = mflushers;
    return run();
  }

  public static boolean run() {
    int totalNumMsgs = 0;
    int totalNumTopics = numTopics + numSlowTopics;
    String[] topicNames;
    if (streamName == null || streamName.length() == 0) {
      System.err.println("stream name cannot be empty.");
      usage();
    }

    if (numPartitions <= 0) {
      System.err.println("num partitions cannot be negative or zero.");
      usage();
    }

    if (numTopics <= 0) {
      System.err.println("num topics cannot be negative or zero.");
    }

    totalNumMsgs = numMsgsPerPartition * numPartitions * numTopics * numStreams;
    //messages for slow topic 1/1000th of other topics
    totalNumMsgs += numMsgsPerPartition/1000 * numPartitions * numSlowTopics * numStreams;

    // Now get all the topic feeds to send messages to.
    topicNames = new String[totalNumTopics * numStreams];
    for (int s = 0; s < numStreams; ++s) {
      for (int i = 0; i < totalNumTopics; ++i) {
        String topicName = streamName + s +":topic" + i;
        topicNames[s * totalNumTopics + i] = topicName;
      }
    }

    Properties props = new Properties();
    MarlinConfigDefaults cdef = MarlinConfigDefaults.getDefaultInstance();
    props.put("key.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    props.put("value.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    props.put(cdef.getParallelFlushersPerPartition(), multipleFlushers);  // Do not allow multiple flushers
    props.put(cdef.getBufferTime(), 1000);  // For tests, just use a small number so that things run more frequently
    props.put(cdef.getMetadataMaxAge(), 1000);


    producer = new KafkaProducer<byte[], byte[]>(props);
    PerfStats stats = new PerfStats(totalNumMsgs);
    for (int sIdx = 0; sIdx < numStreams; ++sIdx) {
      for (int topicIdx = 0; topicIdx < totalNumTopics; ++topicIdx) {
        for (int partIdx = 0; partIdx < numPartitions; ++partIdx) {
          int msgIdx = 0;
          for (int i = 0; i < numMsgsPerPartition; ++i) {
            if (topicIdx >= numTopics) {
              //For slow topics create only 1/1000th messages.
              if (i % 1000 != 0)
                continue;
            }
            String topicName = topicNames[sIdx * totalNumTopics + topicIdx];

            byte[] key = (topicName+":"+partIdx+":"+msgIdx).getBytes();
            int keySz = key.length;

            byte[] value = new byte[MSG_VALUE_LENGTH];
            ByteBuffer valueBuffer = ByteBuffer.wrap(value);
            for (int j = 0; j < MSG_VALUE_LENGTH/keySz; ++j) {
              valueBuffer.put(key);
            }

            long sendStart = System.currentTimeMillis();
            long sendBytes = keySz + MSG_VALUE_LENGTH;
            Callback cb = new PerfCallback(sendStart, sendBytes, stats);

            ProducerRecord<byte[], byte[]> record =
              new ProducerRecord<byte[], byte[]>(topicName, partIdx, key, value);
            producer.send(record, cb);
            ++msgIdx;
          }
        }
      }
    }

    producer.flush();
    boolean verify = stats.checkAndVerify();

    producer.close();
    if (printStats)
      stats.printReport();
    return verify;
  }

  private static final class PerfStats {
    private long startTime;
    private long endTime;
    private long minLatency;
    private long maxLatency;
    private long totalLatency;
    private long totalBytes;
    private long msgCount;
    private long totalMsgsToExpect;

    public PerfStats(int totalMsgs) {
      this.totalMsgsToExpect = totalMsgs;
      this.minLatency = 999999;
      this.maxLatency = 0;
      this.totalBytes = 0;
      this.msgCount = 0;
      this.totalLatency = 0;
      this.endTime = -1;
      this.startTime = System.currentTimeMillis();
    }

    public synchronized boolean checkAndVerify() {
      if (this.totalMsgsToExpect != this.msgCount) {
        System.out.println("***** Verification failed! *****");
        return false;
      }
      return true;
    }

    public synchronized void report(long latency, long bytes) {
      this.minLatency = Math.min(this.minLatency, latency);
      this.maxLatency = Math.max(this.maxLatency, latency);
      this.totalBytes += bytes;
      this.msgCount++;
      this.totalLatency += latency;

      if (this.msgCount == this.totalMsgsToExpect)
        this.endTime = System.currentTimeMillis();
    }

    public synchronized void printReport() {
      System.out.println("Expected nMsgs: " + this.totalMsgsToExpect);
      System.out.println("Callback nMsgs: " + this.msgCount);
      if (this.endTime == -1)
        this.endTime = System.currentTimeMillis();
      System.out.println("Total time (ms): " + (this.endTime - this.startTime));
      System.out.println("Total bytes sent: " + this.totalBytes);
      System.out.println("Min/Max latency (ms): " + this.minLatency + "/" + this.maxLatency);
      System.out.println("Average latency (ms): " + (this.totalLatency/this.msgCount));
      System.out.println("Average nMsgs/sec: " + this.msgCount*1.0/(this.endTime-this.startTime)*1000.0);
      System.out.println("Average nKBs/sec: " + this.totalBytes/(this.endTime-this.startTime)*1000.0/1024.0);
    }
  }

  private static final class PerfCallback implements Callback {

    private final long start;
    private final long bytes;
    private PerfStats stats;

    public PerfCallback(long start, long bytes, PerfStats stats) {
      this.start = start;
      this.bytes = bytes;
      this.stats = stats;
    }

    public void onCompletion(RecordMetadata metadata,
                             Exception exception) {
      long now = System.currentTimeMillis();
      int latency = (int) (now - start);
      stats.report(latency, this.bytes);
      if (exception != null){
        exception.printStackTrace();
        System.exit(1);
      }
    }
  }
}
