/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
package com.mapr.streams.tests.listener;

import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.util.concurrent.Future;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.Iterator;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.hadoop.conf.Configuration;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.common.TopicPartition;

import com.mapr.tests.BaseTest;
import com.mapr.tests.annotations.ClusterTest;

import com.mapr.streams.Admin;
import com.mapr.streams.Streams;
import com.mapr.streams.StreamDescriptor;
import com.mapr.fs.proto.Marlinserver.MarlinConfigDefaults;


/**
 * Cluster tests for message offsets on a stream: offsets reported to the
 * producer, consumer seeks to explicit offsets, and consumer seekToEnd.
 *
 * A shared stream, producer and consumer are created once per class; the
 * seekToEnd tests additionally create (and remove) their own streams.
 */
@Category(ClusterTest.class)
public class OffsetTest extends BaseTest {
  private static final Logger _logger = LoggerFactory.getLogger(OffsetTest.class);
  private static final String STREAM = "/jtest-" + OffsetTest.class.getSimpleName();
  private static final int numParts = 1;

  // Value of the sentinel record that marks the end of a produced batch.
  private static final String TERMINATOR = "terminator";

  // Upper bound on how long consumeMsgs() waits for the terminator record,
  // so a lost record fails the test instead of hanging it forever.
  private static final long CONSUME_TIMEOUT_MS = 5L * 60L * 1000L;

  private static Admin madmin;
  private static KafkaProducer<String, String> producer;
  private static KafkaConsumer<String, String> consumer;

  /** Builds the producer configuration (String keys and values). */
  private static Properties GetProducerProps() throws Exception {
    Properties props = new Properties();
    props.put("key.serializer",
      "org.apache.kafka.common.serialization.StringSerializer");
    props.put("value.serializer",
      "org.apache.kafka.common.serialization.StringSerializer");
    // NOTE(review): presumably disabling parallel flushers keeps the
    // per-partition send order that the offset checks rely on - confirm.
    props.put("streams.parallel.flushers.per.partition", false);

    return props;
  }

  /** Builds the consumer configuration (String keys and values, no auto commit). */
  private static Properties GetConsumerProps() throws Exception {
    Properties props = new Properties();
    props.put("key.deserializer",
      "org.apache.kafka.common.serialization.StringDeserializer");
    props.put("value.deserializer",
      "org.apache.kafka.common.serialization.StringDeserializer");
    props.put("enable.auto.commit", false);

    return props;
  }

  /**
   * Creates 'streamName' with 'numParts' default partitions, first making a
   * best-effort attempt to delete any stale stream of the same name.
   */
  private static void CreateStream(String streamName) throws Exception {
    try {
      madmin.deleteStream(streamName);
    } catch (Exception e) {
      // Best effort only: a failure here (e.g. nothing to delete) must not
      // prevent the stream from being created below.
      _logger.debug("ignoring failure to delete stream {} before creating it",
                    streamName, e);
    }

    StreamDescriptor sdesc = Streams.newStreamDescriptor();
    sdesc.setDefaultPartitions(numParts);
    madmin.createStream(streamName, sdesc);
  }

  /**
   * Closes 'pr' and 'cs' (when non-null) and deletes 'streamName'. Every step
   * is attempted even if an earlier one throws.
   */
  private static void releaseStreamResources(KafkaProducer<String, String> pr,
                                             KafkaConsumer<String, String> cs,
                                             String streamName) throws Exception {
    try {
      if (pr != null) {
        pr.close();
      }
    } finally {
      try {
        if (cs != null) {
          cs.close();
        }
      } finally {
        if (madmin != null) {
          madmin.deleteStream(streamName);
        }
      }
    }
  }

  @BeforeClass
  public static void setupTest() throws Exception {
    final Configuration conf = new Configuration();
    madmin = Streams.newAdmin(conf);

    CreateStream(STREAM);

    producer = new KafkaProducer<String, String>(GetProducerProps());

    // The consumer is configured with String deserializers, so it is typed
    // <String, String> to match the records that consumeMsgs() reads.
    consumer = new KafkaConsumer<String, String>(GetConsumerProps());
  }

  @AfterClass
  public static void cleanupTest() throws Exception {
    // Null-safe and exhaustive so that a partially failed setupTest() or a
    // failing close() does not leave the remaining resources behind.
    releaseStreamResources(producer, consumer, STREAM);
  }

  /**
   * Produces 'nMsgs' messages, each with a value of 'msgSz' characters, then
   * appends a single "terminator" message.
   *
   * @param topicName full topic name to publish to
   * @param msgSz     number of characters in each message value
   * @param nMsgs     number of messages to publish before the terminator
   * @return the offsets of the 'nMsgs' produced messages, in send order
   *         (the terminator's offset is not included)
   * @throws Exception if any send fails
   */
  public List<Long>
  produceMsgs(String topicName, int msgSz, int nMsgs) throws Exception {
    // Prepare msg of reqd size
    StringBuilder payload = new StringBuilder(msgSz);
    for (int i = 0; i < msgSz; i++) {
      payload.append('R');
    }
    String msgValue = payload.toString();

    // Publish the msgs
    List<Future<RecordMetadata>> futureList =
      new ArrayList<Future<RecordMetadata>>(nMsgs);
    for (int i = 0; i < nMsgs; i++) {
      ProducerRecord<String, String> rec =
        new ProducerRecord<String, String>(topicName,
                                           Integer.toString(i),
                                           msgValue);
      futureList.add(producer.send(rec));
    }
    producer.flush();

    Future<RecordMetadata> terminatorFuture =
      producer.send(new ProducerRecord<String, String>(topicName, "", TERMINATOR));
    producer.flush();
    // Surface a failed terminator send here; otherwise the consumer side
    // would wait for a record that never arrives.
    terminatorFuture.get();

    // remember the offsets
    List<Long> offsetsList = new ArrayList<Long>(nMsgs);
    for (Future<RecordMetadata> future : futureList) {
      offsetsList.add(future.get().offset());
    }

    _logger.info("produced {} msgs, each of sz {}", nMsgs, msgSz);
    return offsetsList;
  }

  /**
   * Seeks partition 0 of 'topicName' to 'offset' and consumes until the
   * "terminator" message is seen.
   *
   * @param topicName full topic name to consume from
   * @param offset    offset to seek to before consuming
   * @return the number of messages consumed before the terminator
   * @throws AssertionError if the terminator is not seen within
   *         CONSUME_TIMEOUT_MS
   */
  public int
  consumeMsgs(String topicName, long offset) throws Exception {
    TopicPartition p0 = new TopicPartition(topicName, 0);
    List<TopicPartition> plist = new ArrayList<TopicPartition>();
    plist.add(p0);
    consumer.assign(plist);

    int numMsgs = 0;
    try {
      consumer.seek(p0, offset);

      final long deadline = System.currentTimeMillis() + CONSUME_TIMEOUT_MS;
      boolean terminated = false;
      while (!terminated) {
        if (System.currentTimeMillis() > deadline) {
          throw new AssertionError("no terminator msg on " + topicName +
                                   " within " + CONSUME_TIMEOUT_MS +
                                   " ms after seeking to offset " + offset +
                                   " (consumed " + numMsgs + " msgs)");
        }

        ConsumerRecords<String, String> consumerRecs = consumer.poll(10);
        for (ConsumerRecord<String, String> record : consumerRecs) {
          if (TERMINATOR.equals(record.value())) {
            terminated = true;
            break;
          }
          ++numMsgs;
        }
      }
    } finally {
      // Always drop the assignment so the shared consumer is clean for the
      // next call, even when this one fails.
      consumer.unsubscribe();
    }

    return numMsgs;
  }

  /** Sends 'nMsgs' small messages to 'topicName' through 'kp' and flushes. */
  private void
  writeMsgs(KafkaProducer<String, String> kp, String topicName, int nMsgs)
      throws Exception {
    String msgValue = "seekToEndMessage";

    for (int i = 0; i < nMsgs; i++) {
      kp.send(new ProducerRecord<String, String>(topicName, Integer.toString(i), msgValue));
    }

    kp.flush();
  }

  /**
   * Assigns partition 0 of every topic in 'topicNames' (within 'streamName')
   * to 'cs' and returns the assigned partitions in the same order.
   */
  private List<TopicPartition>
  assignAllTopics(KafkaConsumer<String, String> cs, String streamName,
                  String[] topicNames) {
    List<TopicPartition> partitions = new ArrayList<TopicPartition>(topicNames.length);
    for (String name : topicNames) {
      partitions.add(new TopicPartition(streamName + ":" + name, 0));
    }
    cs.assign(partitions);
    return partitions;
  }

  /**
   * Seeks each partition to its end and checks that the resulting position is
   * the number of messages written to it plus one.
   *
   * @param cs           consumer that has 'partitions' assigned
   * @param partitions   partitions to check
   * @param expectedMsgs number of messages written to each partition, indexed
   *                     like 'partitions'
   */
  private void
  assertSeekToEndPositions(KafkaConsumer<String, String> cs,
                           List<TopicPartition> partitions,
                           int[] expectedMsgs) {
    for (int i = 0; i < partitions.size(); i++) {
      TopicPartition tp = partitions.get(i);
      cs.seekToEnd(tp);
      long seekToEndPosition = cs.position(tp);
      long expectedPosition = expectedMsgs[i] + 1L;
      assertTrue("position after seekToEnd on " + tp + " is " + seekToEndPosition +
                 ", expected " + expectedPosition,
                 seekToEndPosition == expectedPosition);
    }
  }

  /**
   * Creates 'topicCount' topics, writes messages to all of them except the one
   * at 'emptyTopicIdx', and verifies the seekToEnd position of every topic.
   */
  private void testSeekToEnd0(int topicCount, int emptyTopicIdx) throws Exception {
    String[] topicName = new String[topicCount];
    int[] nMsgs = new int[topicCount];
    String streamName = STREAM + "-SeekToEnd0-" + emptyTopicIdx;

    CreateStream(streamName);

    KafkaProducer<String, String> pr = null;
    KafkaConsumer<String, String> cs = null;
    try {
      pr = new KafkaProducer<String, String>(GetProducerProps());

      // create topics; the designated empty topic gets no messages
      for (int i = 0; i < topicCount; i++) {
        nMsgs[i] = (i == emptyTopicIdx) ? 0 : (i + 1) * 1000;
        topicName[i] = "t" + i;
        madmin.createTopic(streamName, topicName[i], numParts);
      }

      // produce messages for all but one topic
      for (int i = 0; i < topicCount; i++) {
        if (nMsgs[i] != 0) {
          writeMsgs(pr, streamName + ":" + topicName[i], nMsgs[i]);
        }
      }

      cs = new KafkaConsumer<String, String>(GetConsumerProps());
      List<TopicPartition> topicPartitionList =
        assignAllTopics(cs, streamName, topicName);

      assertSeekToEndPositions(cs, topicPartitionList, nMsgs);
    } finally {
      releaseStreamResources(pr, cs, streamName);
    }
  }

  /** Runs the single-empty-topic scenario once for each possible empty topic. */
  private void testSeekToEnd0() throws Exception {
    int topicCount = 3;

    for (int i = 0; i < topicCount; i++) {
      testSeekToEnd0(topicCount, i);
    }
  }

  /**
   * Verifies seekToEnd positions on a consumer that was assigned before any
   * message was written, re-checking after each of three rounds of writes.
   */
  private void testSeekToEnd1() throws Exception {
    int topicCount = 5;
    String[] topicName = new String[topicCount];
    int[] nMsgs = new int[topicCount];
    String streamName = STREAM + "-SeekToEnd1";

    CreateStream(streamName);

    KafkaProducer<String, String> pr = null;
    KafkaConsumer<String, String> cs = null;
    try {
      pr = new KafkaProducer<String, String>(GetProducerProps());

      // create topics; only topics 1 and 3 receive messages initially
      for (int i = 0; i < topicCount; i++) {
        nMsgs[i] = (i == 1 || i == 3) ? 100 : 0;
        topicName[i] = "t" + i;
        madmin.createTopic(streamName, topicName[i], numParts);
      }

      cs = new KafkaConsumer<String, String>(GetConsumerProps());
      List<TopicPartition> topicPartitionList =
        assignAllTopics(cs, streamName, topicName);

      // Step 1. Initial writes to the non-empty topics.
      for (int i = 0; i < topicCount; i++) {
        if (nMsgs[i] != 0) {
          writeMsgs(pr, streamName + ":" + topicName[i], nMsgs[i]);
        }
      }
      assertSeekToEndPositions(cs, topicPartitionList, nMsgs);

      // Step 2. A large second batch to the topics that already have data.
      final int nMessages = 100000;
      for (int i = 0; i < topicCount; i++) {
        if (nMsgs[i] != 0) {
          nMsgs[i] += nMessages;
          writeMsgs(pr, streamName + ":" + topicName[i], nMessages);
        }
      }
      assertSeekToEndPositions(cs, topicPartitionList, nMsgs);

      // Step 3. A single message to each topic that was empty so far.
      for (int i = 0; i < topicCount; i++) {
        if (nMsgs[i] == 0) {
          nMsgs[i] = 1;
          writeMsgs(pr, streamName + ":" + topicName[i], nMsgs[i]);
        }
      }
      assertSeekToEndPositions(cs, topicPartitionList, nMsgs);
    } finally {
      releaseStreamResources(pr, cs, streamName);
    }
  }

  @Test
  public void testSeekToEnd() throws Exception {
    testSeekToEnd0();
    testSeekToEnd1();
  }

  /**
   * Produces 'numProduced' messages of 'msgSz' characters and checks that the
   * offsets returned to the producer are strictly increasing.
   */
  private void
  checkProducerOffsets(String topicName, int msgSz, int numProduced) throws Exception {
    List<Long> offsets = produceMsgs(topicName, msgSz, numProduced);
    for (int i = 1; i < numProduced; ++i) {
      long previous = offsets.get(i - 1);
      long current = offsets.get(i);
      assertTrue("offset of msg " + i + " (" + current +
                 ") is not greater than offset of msg " + (i - 1) +
                 " (" + previous + ")",
                 current > previous);
    }
  }

  @Test
  public void testProducerOffsetsSmallMsgs() throws Exception {
    checkProducerOffsets(STREAM + ":psmall", 1 /*msgSz*/, 2000 /*numProduced*/);
  }

  @Test
  public void testProducerOffsetsLargeMsgs() throws Exception {
    checkProducerOffsets(STREAM + ":plarge", 1024 /*msgSz*/, 1000 /*numProduced*/);
  }

  /**
   * Produces 'numProduced' messages of 'msgSz' characters, then checks that
   * seeking to offset 0 yields all of them and that seeking to the offset of
   * message i yields message i and every later one.
   */
  private void
  checkSeeks(String topicName, int msgSz, int numProduced) throws Exception {
    List<Long> offsets = produceMsgs(topicName, msgSz, numProduced);

    /* Offset 0 should return all msgs */
    int numConsumed = consumeMsgs(topicName, 0 /*offset*/);
    _logger.info("consumed {} msgs starting from offset 0", numConsumed);
    assertTrue("consumed " + numConsumed +
               " msgs starting from offset 0, expected " + numProduced,
               numConsumed == numProduced);

    // For each offset, should get all subsequent msgs.
    for (int i = 0; i < numProduced; ++i) {
      long offset = offsets.get(i);
      int expected = numProduced - i;
      numConsumed = consumeMsgs(topicName, offset);
      assertTrue("consumed " + numConsumed +
                 " msgs starting from offset " + offset +
                 ", expected " + expected,
                 numConsumed == expected);
    }
  }

  @Test
  public void testSeeksWithSmallMsgs() throws Exception {
    checkSeeks(STREAM + ":small", 1 /*msgSz*/, 200 /*numProduced*/);
  }

  @Test
  public void testSeeksWithLargeMsgs() throws Exception {
    checkSeeks(STREAM + ":large", 1000 /*msgSz*/, 100 /*numProduced*/);
  }

}
