/* Copyright (c) 2009 & onwards. MapR Tech, Inc., All rights reserved */

package com.mapr.streams.listener;

import java.io.*;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.*;

import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.serialization.ByteArrayDeserializer;


public class ListenerPerformance implements Runnable {

  private String streamName;
  private String streamNameToCheck;
  private int numTopics = 2;
  private int numPartitions = 4;
  private int numExpectedMsgs = 100000;
  private int numBatches = 1;

  private boolean verifyKeys = false;
  private boolean keysInOrder = false;
  private boolean allowDuplicateKeys = false;
  private boolean isTracingEnabled = false;
  private boolean printProgress = false;
  private boolean topicSubscriptions = false;
  private String groupId = null;
  private String clientId = null;

  private Hashtable<PartitionInfo, Integer> partitionSeqMap;
  private Hashtable<PartitionInfo, Long> partitionOffsetMap;
  private Hashtable<PartitionInfo, boolean[]> partitionBArrayMap;
  public boolean status;

  public static void usage() {
    System.err.println("ListenerPerformance -path <stream-full-name>");
    System.err.println("     [ -ntopics <num topics> (default: 2) ]");
    System.err.println("     [ -npart <numpartitions per topic> (default: 1) ]");
    System.err.println("     [ -nmsgs <num messages per topicfeed> (default: 100000) ]");
    //System.err.println("       -nbatches <number of batches>");
    //System.err.println("       -verify <true/false>");
    //System.err.println("       -keysinorder <true/false>");
    //System.err.println("       -debug <true/false>");
    //System.err.println("       -allowduplicates <true/false>");
    System.err.println("     [ -group <consumer group id> (default: null) ]");
    //System.err.println("   [   -client <client id>");
    System.err.println("     [ -topicsubscription <true/false> (default: false) ]");
    //System.err.println("       -progress <true/false> ]");
    System.exit(1);
  }

  public static void main(String[] args) throws Exception {
    String streamName = null;
    int numTopics = 2;
    int numPartitions = 1;
    int numExpectedMsgs = 100000;
    int numBatches = 1;
    boolean verifyKeys = true;
    boolean keysInOrder = false;
    boolean allowDuplicateKeys = false;
    boolean isTracingEnabled = false;
    boolean printProgress = true;
    boolean topicSubscriptions = false;
    String groupId = null;
    String clientId = null;

    for (int i = 0; i < args.length; ++i) {
      if (args[i].equals("-path")) {
        i++;
        if (i >= args.length) usage();
        streamName = args[i];
      } else if (args[i].equals("-ntopics")) {
        i++;
        if (i >= args.length) usage();
        numTopics = Integer.parseInt(args[i]);
      } else if (args[i].equals("-npart")) {
        i++;
        if (i >= args.length) usage();
        numPartitions = Integer.parseInt(args[i]);
      } else if (args[i].equals("-nmsgs")) {
        i++;
        if (i >= args.length) usage();
        numExpectedMsgs = Integer.parseInt(args[i]);
      } else if (args[i].equals("-verify")) {
        i++;
        if (i >= args.length) usage();
        verifyKeys = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-keysinorder")) {
        i++;
        if (i >= args.length) usage();
        keysInOrder = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-debug")) {
        i++;
        if (i >= args.length) usage();
        isTracingEnabled = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-allowduplicates")) {
        i++;
        if (i >= args.length) usage();
        allowDuplicateKeys = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-nbatches")) {
        i++;
        if (i >= args.length) usage();
        numBatches = Integer.parseInt(args[i]);
      } else if (args[i].equals("-group")) {
        i++;
        if (i >= args.length) usage();
        groupId = args[i];
      } else if (args[i].equals("-client")) {
        i++;
        if (i >= args.length) usage();
        clientId = args[i];
      } else if (args[i].equals("-topicsubscription")) {
        i++;
        if (i >= args.length) usage();
        topicSubscriptions = Boolean.parseBoolean(args[i]);
      } else if (args[i].equals("-progress")) {
        i++;
        if (i >= args.length) usage();
        printProgress = Boolean.parseBoolean(args[i]);
      } else {
        usage();
      }
    }

    ListenerPerformance lp = new ListenerPerformance(streamName, numTopics,
                                                     numPartitions,
                                                     numExpectedMsgs,
                                                     numBatches,
                                                     verifyKeys, keysInOrder,
                                                     allowDuplicateKeys,
                                                     isTracingEnabled,
                                                     printProgress, groupId,
                                                     clientId,
                                                     topicSubscriptions);
    Thread lt = new Thread(lp);
    lt.start();

    lt.join();
    if (lp.status == false) {
      System.out.println("ListenerPerformance test failed.");
    }
  }

  public ListenerPerformance(String streamName, int numTopics,
                             int numPartitions,
                             int numExpectedMsgs, int numBatches,
                             boolean verifyKeys, boolean keysInOrder, boolean
                             allowDuplicateKeys, boolean isTracingEnabled,
                             boolean printProgress,
                             String groupId,
                             String clientId,
                             boolean topicSubscriptions) {
    this.streamName = streamName;
    this.streamNameToCheck = streamName;
    this.numTopics = numTopics;
    this.numPartitions = numPartitions;
    this.numExpectedMsgs = numExpectedMsgs;
    this.numBatches = numBatches;
    this.verifyKeys = verifyKeys;
    this.keysInOrder = keysInOrder;
    this.allowDuplicateKeys = allowDuplicateKeys;
    this.isTracingEnabled = isTracingEnabled;
    this.printProgress = printProgress;
    this.groupId = groupId;
    this.clientId = clientId;
    this.topicSubscriptions = topicSubscriptions;
  }

  public ListenerPerformance(String streamName,
                             String streamNameToCheck, int numTopics,
                             int numPartitions, int numExpectedMsgs,
                             int numBatches, boolean verifyKeys,
                             boolean keysInOrder, boolean allowDuplicateKeys,
                             boolean isTracingEnabled, boolean printProgress,
                             String groupId,
                             String clientId,
                             boolean topicSubscriptions) {
    this.streamName = streamName;
    this.streamNameToCheck = streamNameToCheck;
    this.numTopics = numTopics;
    this.numPartitions = numPartitions;
    this.numExpectedMsgs = numExpectedMsgs;
    this.numBatches = numBatches;
    this.verifyKeys = verifyKeys;
    this.keysInOrder = keysInOrder;
    this.allowDuplicateKeys = allowDuplicateKeys;
    this.isTracingEnabled = isTracingEnabled;
    this.printProgress = printProgress;
    this.groupId = groupId;
    this.clientId = clientId;
    this.topicSubscriptions = topicSubscriptions;

  }

  public void reportProgress(boolean enabled) {
    this.printProgress = enabled;
  }

  public void run() {
    int pollsWithMissingMsgs = 0;
    status = false;
    try {
      if (streamName == null || streamName.length() == 0) {
        System.err.println("stream name cannot be empty.");
        ListenerPerformance.usage();
      }

      if (keysInOrder) {
        partitionSeqMap = new Hashtable<PartitionInfo, Integer>();
      } else {
        partitionBArrayMap = new Hashtable<PartitionInfo, boolean[]>();
      }

      partitionOffsetMap = new Hashtable<PartitionInfo, Long>();

      Properties props = new Properties();
      props.put("key.deserializer",
          "org.apache.kafka.common.serialization.ByteArrayDeserializer");
      props.put("value.deserializer",
          "org.apache.kafka.common.serialization.ByteArrayDeserializer");
      props.put("auto.offset.reset", "earliest");
      props.put("auto.commit.interval.ms", 1000);
      props.put("fetch.min.bytes", 1);
      props.put("max.partition.fetch.bytes", 64 * 1024);

      if (groupId != null) {
        props.put("group.id", groupId);
      }
      if (clientId != null) {
        props.put("client.id", clientId);
      }

      ByteArrayDeserializer keyD = new ByteArrayDeserializer();
      ByteArrayDeserializer valueD = new ByteArrayDeserializer();
      KafkaConsumer listener = new KafkaConsumer<byte[], byte[]>(props, keyD, valueD);
      ConsumerRebalanceListener cb = new RebCb();

      List<TopicPartition> partitions = new ArrayList<TopicPartition>();
      List<String> topicList = new ArrayList<String>();

      for (int i = 0; i < numTopics; i++) {
        String streamTopicName = streamName + ":topic" + i;
        topicList.add(streamTopicName);
        for (int j = 0; j < numPartitions; j++) {
          partitions.add(new TopicPartition(streamTopicName, j));
          if (isTracingEnabled)
            System.out.println("Creating a subscription for " + streamTopicName + " feed:" + j);
        }
      }

      if (topicSubscriptions) {
        listener.subscribe(topicList, cb);
      } else {
        listener.assign(partitions);
      }

      System.out.println("Subscription successful");

      PerfStats stats = new PerfStats();
      int totalNumMsgs = 0;
      int totalNumMsgsExpected = numTopics*numPartitions*numExpectedMsgs*numBatches;
      while (totalNumMsgs < totalNumMsgsExpected) {
        if (isTracingEnabled)
          System.out.println("pollsWithMissingMsgs "+ pollsWithMissingMsgs);

        ConsumerRecords<byte[], byte[]> recs = listener.poll(1000);
        if (recs.count() == 0) {
          ++pollsWithMissingMsgs;
        } else {
          pollsWithMissingMsgs = 0;
        }
        totalNumMsgs += VerifyAndAddStats(recs, stats);
      }
      if (verifyKeys) {
        VerifyPollingEnd();
      }
      stats.printReport();

      Set<TopicPartition> subscribed = listener.assignment();

      while (subscribed.size() == 0) {
        try {
          System.out.println("Assignment size " + subscribed.size());
          Thread.sleep(1000);
        } catch (Exception e) {
        }
        subscribed = listener.assignment();
      }

      if (isTracingEnabled) {
        for (TopicPartition p : subscribed) {
          System.out.println("Subscribed to " + p.topic() + " partition:" +
                             p.partition());
        }
      }

      if (verifyKeys) {
        System.out.println("committing offsets");
        listener.commitSync();

        System.out.println("Subscription size " + subscribed.size());
        for (TopicPartition p : subscribed) {
          long offset = listener.committed(p).offset();
          String[] topicNameParts = p.topic().split(":");
          PartitionInfo pinfo = new PartitionInfo(streamNameToCheck,
                                                  topicNameParts[1], p.partition());
          System.out.println("partition check: " + topicNameParts[0] + " " + topicNameParts[1]);
          Long mappedoffset = partitionOffsetMap.get(pinfo);
          if ((mappedoffset + 1) != offset) {
            System.out.println("Commit offset for  " + p.topic() + " partition " +
                               p.partition() + ":" + offset + " expected:" +
                               (mappedoffset));
            throw new IOException("unexpected commit offset");
          }
        }
      }

      listener.close();
      status = true;
    } catch (IOException e) {
      System.out.println(e);
    }
  }

  private int VerifyAndAddStats(ConsumerRecords<byte[], byte[]> recs, PerfStats stats) throws IOException {
    Iterator<ConsumerRecord<byte[], byte[]>> iter = recs.iterator();
    long numBytes = 0;
    long maxLag = 0;
    long totalLag = 0;
    int numMsgs = 0;
    long now = System.currentTimeMillis();
    while (iter.hasNext()) {
      ConsumerRecord<byte[], byte[]> rec = iter.next();
      if (isTracingEnabled)
        System.out.println(rec);
      byte[] key = null;
      byte[] value = null;
      long offset;
      try {
        key = rec.key();
        value = rec.value();
        ByteBuffer buf = ByteBuffer.wrap(value);
        long producerTime = buf.getLong();
        long lag = now - producerTime;
        if (isTracingEnabled) {
          String keyStr = new String(key, "UTF-8");
          System.out.println("Producer Time " + producerTime + " lag " + lag + " for key " + keyStr);
        }
        if (lag > maxLag) maxLag = lag;
        totalLag += lag;
        offset = rec.offset();
      } catch (Exception e) {
        throw new IOException("ConsumerRecord Exception");
      }
      numMsgs++;
      numBytes += key.length;
      numBytes += value.length;
      if (verifyKeys) {
        String keyStr = new String(key, "UTF-8");
        String[] tokens = keyStr.split(":");
        if (tokens.length != 4) {
          throw new IOException("Key " + keyStr + " not of correct format");
        }

        int partition = Integer.parseInt(tokens[2]);
        int seq = Integer.parseInt(tokens[3]);
        if (isTracingEnabled)
          System.out.println("Key " + keyStr + " ntokens " + tokens.length + " bytes " + (key.length + value.length) + " offset " + offset);
        if (!tokens[0].equals(streamNameToCheck)) {
          throw new IOException("streamName in key " + tokens[0] + " not same as " + streamName);
        }
        String recTopic = rec.topic();
        String[] recTopicTokens = recTopic.split(":");
        if ((recTopicTokens.length != 2) ||
            !tokens[1].equals(recTopicTokens[1])) {
          throw new IOException("topic in key " + tokens[1] + " mismatched. Expected " + recTopicTokens[1]);
        }
        if (partition != rec.partition()) {
          throw new IOException("partition in key " + partition + " mismatched. Expected " + rec.partition());
        }
        if (keysInOrder) {
          VerifyPollingOrderedKey(tokens[0], tokens[1], partition, seq, offset);
        } else {
          VerifyPollingUnorderedKey(tokens[0], tokens[1], partition, seq, offset);
        }
      }
      if (isTracingEnabled)
        System.out.println("Value " + new String(value, "UTF-8"));
    }
    stats.report(numBytes, recs.count(), maxLag, totalLag, now);
    return numMsgs;
  }

  private void VerifyPollingOrderedKey(String stName, String tpName, int
      partition, int seq, long offset)
      throws IOException
  {
    PartitionInfo pinfo = new PartitionInfo(stName, tpName, partition);
    Integer mappedSeq = partitionSeqMap.get(pinfo);
    int expSeq = 0;
    if (mappedSeq != null) {
      expSeq = mappedSeq + 1;
    }
    if (seq != expSeq) {
      throw new IOException("Current Seq " + seq + " for Stream " + stName +
                            " Topic " + tpName + " partition " + partition +
                            " mismatched. Expected " + expSeq);
    }
    partitionSeqMap.put(pinfo, seq);
    
    Long mappedoffset = partitionOffsetMap.get(pinfo);
    if (mappedoffset != null) {
      if (mappedoffset > offset) {
        throw new IOException("Got out of order offsets");
      }
    }
    partitionOffsetMap.put(pinfo, offset);
  }

  private void VerifyPollingUnorderedKey(String stName, String tpName, int
      partition, int seq, long offset)
      throws IOException
  {
    PartitionInfo pinfo = new PartitionInfo(stName, tpName, partition);
    boolean[] bitArray = partitionBArrayMap.get(pinfo);
    if (bitArray == null) {
      //TODO(NARENDRA): use BitSet instead.
      bitArray = new boolean[numExpectedMsgs*numBatches];
      partitionBArrayMap.put(pinfo, bitArray);
    }
    if (bitArray[seq] && !allowDuplicateKeys) {
      throw new IOException("Duplicate key for Stream " + stName + " Topic " + tpName + " partition " + partition + ". seq is " + seq);
    }
    bitArray[seq] = true;
    
    Long mappedoffset = partitionOffsetMap.get(pinfo);
    if (mappedoffset != null) {
      if (mappedoffset > offset) {
        throw new IOException("Got out of order offsets");
      }
    }
    partitionOffsetMap.put(pinfo, offset);
  }

  private void VerifyPollingEnd() throws IOException
  {
    int totalPartitions = numTopics * numPartitions;
    if (keysInOrder) {
      int lastExpectedSeq = (numExpectedMsgs*numBatches) - 1;
      if (partitionSeqMap.size() != totalPartitions) {
        throw new IOException("Total entries in hashmap " + partitionSeqMap.size() + ", expected " + totalPartitions);
      }
      Set<Map.Entry<PartitionInfo, Integer>> set = partitionSeqMap.entrySet();
      for (Map.Entry<PartitionInfo, Integer> entry : set) {
        if (entry.getValue() != lastExpectedSeq) {
          throw new IOException(entry.getKey() + ", Last seq received " + entry.getValue() + ", expected " + lastExpectedSeq);
        }
      }
    } else {
      if (partitionBArrayMap.size() != totalPartitions) {
        throw new IOException("Total entries in hashmap " + partitionBArrayMap.size() + ", expected " + totalPartitions);
      }
      Set<Map.Entry<PartitionInfo, boolean[]>> set = partitionBArrayMap.entrySet();
      for (Map.Entry<PartitionInfo, boolean[]> entry : set) {
        boolean[] barray = entry.getValue();
        for (int i = 0; i < (numExpectedMsgs*numBatches); i++) {
          if (!barray[i]) {
            throw new IOException("Message with seq " + i + " missing");
          }
        }
      }
    }
  }

  private static final class RebCb implements ConsumerRebalanceListener {

    RebCb() { }

    public void
      onPartitionsAssigned(Collection<TopicPartition> partitions) { }

    public void
      onPartitionsRevoked(Collection<TopicPartition> partitions) {}
  }

  private final class PerfStats {
    private long startTime;
    private long endTime;
    private long totalBytes;
    private long totalMsgs;
    private long maxLag;
    private long totalLag;
    private long lastProgressTime;
    private long lastProgressMsgs;
    private int nsecs;

    public PerfStats() {
      this.endTime = -1;
      this.totalBytes = 0;
      this.totalMsgs = 0;
      this.maxLag = 0;
      this.totalLag = 0;
      this.startTime = System.currentTimeMillis();
      this.lastProgressTime = this.startTime;
      this.lastProgressMsgs = 0;
      this.nsecs = 0;
    }

    public synchronized void report(long bytes, long numMsgs,
                                    long maxLag, long totalLag, long now) {
      this.totalLag += totalLag;
      this.totalBytes += bytes;
      this.totalMsgs += numMsgs;
      if (maxLag > this.maxLag) {
        this.maxLag = maxLag;
      }

      if (printProgress && (now - this.lastProgressTime >= 1000)) {
        DateFormat df = new SimpleDateFormat("dd/MM/yy HH:mm:ss");
        Date date = new Date();

        System.out.printf("%s %4d secs %9d msgs %8d msgs/s maxlag %d %n",
                          df.format(date),
                          nsecs,
                          this.totalMsgs,
                          this.totalMsgs - this.lastProgressMsgs,
                          this.maxLag);

        this.lastProgressMsgs = this.totalMsgs;
        this.lastProgressTime = now;
        this.nsecs++;
      }
    }

    public synchronized void printReport() {
      this.endTime = System.currentTimeMillis();
      System.out.println("Total time (ms): " + (this.endTime - this.startTime));
      System.out.println("Total bytes received: " + this.totalBytes);
      System.out.println("Total messages received: " + this.totalMsgs);
      long bytesInKb = totalBytes / 1024;
      System.out.println("Average nKBs/sec: " + bytesInKb * 1.0/(this.endTime-this.startTime)*1000.0);
      System.out.println("Average nMsgs/sec: " + this.totalMsgs * 1.0/(this.endTime-this.startTime)*1000.0);
      System.out.println("Average lag in ms: " + this.totalLag / this.totalMsgs);
      System.out.println("Maximum lag in ms: " + this.maxLag);
    }
  }

  private class PartitionInfo {
    private final String streamName;
    private final String topicName;
    private final int partitionId;

    public PartitionInfo(String streamName, String topicName, int partitionId) {
      this.streamName = streamName;
      this.topicName = topicName;
      this.partitionId = partitionId;
    }

    public boolean equals(Object anObject) {
      if (this == anObject) {
        return true;
      }
      if (anObject instanceof PartitionInfo) {
        PartitionInfo pinfo = (PartitionInfo)anObject;
        if (streamName.equals(pinfo.streamName()) && topicName.equals(pinfo.topicName()) && (partitionId == pinfo.partitionId()))
          return true;
      }
      return false;
    }

    public int hashCode() {
      return topicName.hashCode() + partitionId;
    }

    @Override
    public String toString() {
      return "Stream: " + streamName + " Topic: " + topicName + " Partition: " + partitionId;
    }

    public String streamName() { return streamName; }
    public String topicName() { return topicName; }
    public int partitionId() { return partitionId; }
  }
}

