/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
package com.mapr.streams.tests.listener;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.Future;
import java.util.regex.Pattern;

import org.apache.hadoop.conf.Configuration;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.clients.consumer.OffsetCommitCallback;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.serialization.ByteArrayDeserializer;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.mapr.fs.proto.Marlinserver.MarlinConfigDefaults;
import com.mapr.streams.Admin;
import com.mapr.streams.StreamDescriptor;
import com.mapr.streams.Streams;
import com.mapr.tests.BaseTest;
import com.mapr.tests.annotations.ClusterTest;

/**
 * Cluster tests for the {@code streams.producer.default.stream} and
 * {@code streams.consumer.default.stream} settings.
 *
 * <p>A bare topic name is expected to resolve against the configured default stream
 * ({@code DEFAULTSTREAM:topic}). A fully qualified {@code stream:topic} name is used as-is.
 */
@Category(ClusterTest.class)
public class ProducerAndListenerDefaultStream extends BaseTest {
  private static final Logger _logger = LoggerFactory.getLogger(ProducerAndListenerDefaultStream.class);
  private static final String STREAM = "/jtest-" + ProducerAndListenerDefaultStream.class.getSimpleName();
  private static final String DEFAULTSTREAM = "/jtest-DEFAULTSTREAM";
  private static final String STREAMLISTTOPIC = STREAM + "-listtopic";
  private static final String DEFAULTSTREAMLISTTOPIC = DEFAULTSTREAM + "-listtopic";

  /** Every stream this class creates, in creation order. */
  private static final List<String> ALL_STREAMS = Collections.unmodifiableList(
      Arrays.asList(STREAM, DEFAULTSTREAM, STREAMLISTTOPIC, DEFAULTSTREAMLISTTOPIC));

  /**
   * Default partition count for created streams.
   *
   * <p>Several assertions below hard-code partition 0 and rely on this being 1.
   */
  private static final int NUM_PARTS = 1;

  /** Poll timeout handed to {@code KafkaConsumer.poll}. */
  private static final long POLL_TIMEOUT_MS = 1000L;

  /** Upper bound on total polling time before consumption is declared failed rather than hanging. */
  private static final long CONSUME_TIMEOUT_MS = 120000L;

  private static Admin madmin;

  @BeforeClass
  public static void setupTestClass() throws Exception {
    madmin = Streams.newAdmin(new Configuration());

    // Remove streams left behind by an earlier, aborted run.
    for (String stream : ALL_STREAMS) {
      deleteStreamQuietly(stream);
    }

    StreamDescriptor sdesc = Streams.newStreamDescriptor();
    sdesc.setDefaultPartitions(NUM_PARTS);
    for (String stream : ALL_STREAMS) {
      madmin.createStream(stream, sdesc);
    }
  }

  /**
   * Deletes every stream created by {@link #setupTestClass()}.
   *
   * <p>Every deletion is attempted even if an earlier one fails. The first failure is rethrown at
   * the end.
   */
  @AfterClass
  public static void cleanupTestClass() throws Exception {
    if (madmin == null) {
      return; // setupTestClass failed before the admin client existed; nothing to clean up.
    }
    Exception firstFailure = null;
    for (String stream : ALL_STREAMS) {
      try {
        madmin.deleteStream(stream);
      } catch (Exception e) {
        _logger.warn("Failed to delete stream {}", stream, e);
        if (firstFailure == null) {
          firstFailure = e;
        }
      }
    }
    if (firstFailure != null) {
      throw firstFailure;
    }
  }

  /**
   * Best-effort stream deletion used before the tests run.
   *
   * <p>Failures are logged and ignored. The usual cause is presumably that the stream does not
   * exist.
   */
  private static void deleteStreamQuietly(String stream) {
    try {
      madmin.deleteStream(stream);
    } catch (Exception ignored) {
      _logger.debug("Pre-test deletion of stream {} failed: {}", stream, ignored.toString());
    }
  }

  /** Returns the MapR Streams fully qualified topic name {@code stream:topic}. */
  private static String fullTopicName(String stream, String topic) {
    return stream + ":" + topic;
  }

  /** Producer configuration with byte-array serializers and {@code DEFAULTSTREAM} as the default stream. */
  private static Properties producerProps() {
    MarlinConfigDefaults cdef = MarlinConfigDefaults.getDefaultInstance();
    Properties props = new Properties();
    props.put("key.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    props.put("value.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    // NOTE(review): presumably disables parallel flushing so offsets are assigned in send order,
    // which the strictly-increasing-offset assertions rely on. TODO confirm.
    props.put(cdef.getParallelFlushersPerPartition(), "false");
    props.put("streams.producer.default.stream", DEFAULTSTREAM);
    return props;
  }

  /**
   * Consumer configuration with byte-array deserializers and {@code auto.offset.reset=earliest}.
   *
   * @param defaultStream value for {@code streams.consumer.default.stream}, or {@code null} to
   *     leave it unset
   */
  private static Properties consumerProps(String defaultStream) {
    Properties props = new Properties();
    props.put("key.deserializer",
              "org.apache.kafka.common.serialization.ByteArrayDeserializer");
    props.put("value.deserializer",
              "org.apache.kafka.common.serialization.ByteArrayDeserializer");
    props.put("auto.offset.reset", "earliest");
    if (defaultStream != null) {
      props.put("streams.consumer.default.stream", defaultStream);
    }
    return props;
  }

  /**
   * Produces {@code numMsgs} messages to each of two topics, checks the send metadata, and
   * optionally consumes them back.
   *
   * <p>The two topics are the explicit {@code STREAM:topicname} and the bare {@code topicname}. The
   * bare name must resolve to {@code DEFAULTSTREAM:topicname}. The send metadata is checked for
   * partition 0, strictly increasing offsets and the resolved topic name.
   *
   * <p>When {@code consumer} is non-null, it subscribes to both names and polls until exactly
   * {@code numMsgs} records have been read from each topic. It fails instead of hanging if that
   * takes longer than {@code CONSUME_TIMEOUT_MS}.
   *
   * @param numMsgs number of messages sent to each of the two topics
   * @param topicname bare topic name
   * @param producer byte[]/byte[] producer expected to use {@code DEFAULTSTREAM} as its default
   *     stream; always closed by this method
   * @param consumer byte[]/byte[] consumer expected to use {@code DEFAULTSTREAM} as its default
   *     stream, or {@code null} to skip consumption; closed by this method when non-null
   */
  public void testProducerAndConsumer(int numMsgs, String topicname, KafkaProducer producer, KafkaConsumer consumer) throws Exception {
    @SuppressWarnings("unchecked")
    KafkaProducer<byte[], byte[]> typedProducer = producer;

    String explicitTopic = fullTopicName(STREAM, topicname);
    String defaultTopic = fullTopicName(DEFAULTSTREAM, topicname);

    List<Future<RecordMetadata>> futuresDefault = new ArrayList<Future<RecordMetadata>>(numMsgs);
    List<Future<RecordMetadata>> futuresNonDefault = new ArrayList<Future<RecordMetadata>>(numMsgs);
    try {
      for (int i = 0; i < numMsgs; ++i) {
        byte[] key = ("key-value" + i).getBytes(StandardCharsets.UTF_8);
        byte[] msg = ("msg-value" + i).getBytes(StandardCharsets.UTF_8);
        futuresNonDefault.add(
            typedProducer.send(new ProducerRecord<byte[], byte[]>(explicitTopic, key, msg)));
        // Bare topic name: must be routed to the producer's default stream.
        futuresDefault.add(
            typedProducer.send(new ProducerRecord<byte[], byte[]>(topicname, key, msg)));
      }
      typedProducer.flush();
    } finally {
      typedProducer.close();
    }

    assertSequentialMetadata(futuresDefault, defaultTopic);
    assertSequentialMetadata(futuresNonDefault, explicitTopic);

    if (consumer == null) {
      return;
    }

    @SuppressWarnings("unchecked")
    KafkaConsumer<byte[], byte[]> typedConsumer = consumer;
    try {
      List<String> topics = new ArrayList<String>(2);
      topics.add(topicname);
      topics.add(explicitTopic);

      RebalanceCb cb = new RebalanceCb();
      typedConsumer.subscribe(topics, cb);
      cb.assignDone();

      consumeAll(typedConsumer, defaultTopic, explicitTopic, numMsgs);
    } finally {
      typedConsumer.close();
    }
  }

  /**
   * Waits for every send in {@code futures} and checks where the records landed.
   *
   * <p>Each record must be in partition 0 of {@code expectedTopic}, at an offset strictly greater
   * than the previous record's.
   */
  private static void assertSequentialMetadata(List<Future<RecordMetadata>> futures,
                                               String expectedTopic) throws Exception {
    long previousOffset = -1;
    for (Future<RecordMetadata> future : futures) {
      RecordMetadata rm = future.get();
      assertEquals(0, rm.partition());
      assertTrue("offset " + rm.offset() + " is not after " + previousOffset,
                 previousOffset < rm.offset());
      previousOffset = rm.offset();
      assertEquals(expectedTopic, rm.topic());
    }
  }

  /**
   * Polls until {@code numMsgs} records have arrived from partition 0 of each topic.
   *
   * <p>Fails if that takes longer than {@code CONSUME_TIMEOUT_MS}. Also fails if more than
   * {@code numMsgs} records arrive from either topic.
   */
  private static void consumeAll(KafkaConsumer<byte[], byte[]> consumer, String defaultTopic,
                                 String explicitTopic, int numMsgs) {
    int countDefault = 0;
    int countNonDefault = 0;
    long deadline = System.currentTimeMillis() + CONSUME_TIMEOUT_MS;
    while (countDefault < numMsgs || countNonDefault < numMsgs) {
      if (System.currentTimeMillis() > deadline) {
        fail("Timed out after " + CONSUME_TIMEOUT_MS + " ms: received "
             + countDefault + "/" + numMsgs + " from " + defaultTopic + " and "
             + countNonDefault + "/" + numMsgs + " from " + explicitTopic);
      }
      ConsumerRecords<byte[], byte[]> recs = consumer.poll(POLL_TIMEOUT_MS);
      countDefault += countRecords(recs, defaultTopic);
      countNonDefault += countRecords(recs, explicitTopic);
    }
    assertEquals(numMsgs, countDefault);
    assertEquals(numMsgs, countNonDefault);
  }

  /**
   * Counts the records in {@code recs} for partition 0 of {@code topic}.
   *
   * <p>Also asserts that each of those records reports that topic and partition.
   */
  private static int countRecords(ConsumerRecords<byte[], byte[]> recs, String topic) {
    int count = 0;
    for (ConsumerRecord<byte[], byte[]> rec : recs.records(new TopicPartition(topic, 0))) {
      assertEquals(0, rec.partition());
      assertEquals(topic, rec.topic());
      count++;
    }
    return count;
  }

  @Test
  public void testDefaultStreamNameForProducer() throws Exception {
    int numMsgs = 10;
    String topicname = "producertest";
    KafkaProducer<byte[], byte[]> producer = new KafkaProducer<byte[], byte[]>(producerProps());

    testProducerAndConsumer(numMsgs, topicname, producer, null);
  }

  /**
   * Checks consumer-side resolution of a bare topic name against the default stream.
   *
   * <p>Covers subscription, assignment, position, commit/committed, and manual assignment.
   */
  @Test
  public void testDefaultStreamNameForConsumer() throws Exception {
    String topicname = "consumertest";
    String defaultTopic = fullTopicName(DEFAULTSTREAM, topicname);
    String explicitTopic = fullTopicName(STREAM, topicname);

    madmin.createTopic(DEFAULTSTREAM, topicname);
    madmin.createTopic(STREAM, topicname);

    TopicPartition bareTp = new TopicPartition(topicname, 0);
    TopicPartition defaultTp = new TopicPartition(defaultTopic, 0);
    TopicPartition explicitTp = new TopicPartition(explicitTopic, 0);

    Set<String> correctSubscription = new HashSet<String>(Arrays.asList(defaultTopic, explicitTopic));
    Set<TopicPartition> correctAssignment = new HashSet<TopicPartition>(Arrays.asList(defaultTp, explicitTp));

    KafkaConsumer<byte[], byte[]> consumer =
        new KafkaConsumer<byte[], byte[]>(consumerProps(DEFAULTSTREAM));
    try {
      List<String> topics = new ArrayList<String>(2);
      topics.add(topicname);
      topics.add(explicitTopic);

      RebalanceCb cb = new RebalanceCb();
      consumer.subscribe(topics, cb);
      cb.assignDone();

      Set<String> subscription = consumer.subscription();
      Set<TopicPartition> assignment = consumer.assignment();

      // Since NUM_PARTS == 1, each subscribed topic contributes exactly one assigned partition.
      assertEquals(subscription.size(), assignment.size());
      assertEquals(correctSubscription, subscription);
      assertEquals(correctAssignment, assignment);

      assertEquals(0L, consumer.position(bareTp));

      Map<TopicPartition, OffsetAndMetadata> toCommit = new HashMap<TopicPartition, OffsetAndMetadata>();
      toCommit.put(bareTp, new OffsetAndMetadata(100L));
      toCommit.put(explicitTp, new OffsetAndMetadata(200L));
      consumer.commitSync(toCommit);

      // A commit made under the bare name must be visible under the fully qualified default name.
      assertEquals(100L, consumer.committed(bareTp).offset());
      assertEquals(100L, consumer.committed(defaultTp).offset());
      assertEquals(200L, consumer.committed(explicitTp).offset());

      consumer.unsubscribe();

      List<TopicPartition> partitions = new ArrayList<TopicPartition>(2);
      partitions.add(bareTp);
      partitions.add(explicitTp);
      consumer.assign(partitions);

      assignment = consumer.assignment();
      assertTrue(consumer.subscription().isEmpty());
      assertEquals(2, assignment.size());
      assertEquals(correctAssignment, assignment);
    } finally {
      consumer.close();
    }
  }

  @Test
  public void testDefaultStreamNameForProducerAndConsumer() throws Exception {
    int numMsgs = 10;
    String topicname = "listenerPollTest";
    KafkaProducer<byte[], byte[]> producer = new KafkaProducer<byte[], byte[]>(producerProps());
    KafkaConsumer<byte[], byte[]> consumer =
        new KafkaConsumer<byte[], byte[]>(consumerProps(DEFAULTSTREAM));

    testProducerAndConsumer(numMsgs, topicname, producer, consumer);
  }

  /**
   * Exercises {@code listTopics()}, {@code listTopics(String stream)} and
   * {@code listTopics(Pattern)}.
   *
   * <p>Each call is made on a consumer without a default stream and on one with a default stream.
   */
  @Test
  public void testListTopicsForConsumer() throws Exception {
    String topicname = "list";
    int numTopics = 31;
    // Topic "list<i>" gets i+1 partitions so assertTopicListing can tell the topics apart.
    for (int i = 0; i < numTopics; ++i) {
      madmin.createTopic(DEFAULTSTREAMLISTTOPIC, topicname + i, i + 1);
      madmin.createTopic(STREAMLISTTOPIC, topicname + i, i + 1);
    }

    String streamPrefix = fullTopicName(STREAMLISTTOPIC, topicname);
    String defaultPrefix = fullTopicName(DEFAULTSTREAMLISTTOPIC, topicname);
    // No created topic begins with "listA".
    String noMatchPattern = "listA.*";
    // "list" followed by one character: expected to match list0 .. list9 (10 topics).
    String singleCharPattern = "list.$";

    // Consumer WITHOUT a default stream: calls without a stream name should find nothing.
    KafkaConsumer<byte[], byte[]> noDefaultConsumer =
        new KafkaConsumer<byte[], byte[]>(consumerProps(null));
    try {
      assertEquals(0, noDefaultConsumer.listTopics().size());

      assertTopicListing(noDefaultConsumer.listTopics(STREAMLISTTOPIC), streamPrefix, numTopics);
      assertTopicListing(noDefaultConsumer.listTopics(DEFAULTSTREAMLISTTOPIC), defaultPrefix, numTopics);

      assertEquals(0, noDefaultConsumer.listTopics(Pattern.compile(noMatchPattern)).size());
      assertEquals(0, noDefaultConsumer.listTopics(
          Pattern.compile(STREAMLISTTOPIC + ":" + noMatchPattern)).size());
      assertEquals(0, noDefaultConsumer.listTopics(Pattern.compile(singleCharPattern)).size());
      assertTopicListing(noDefaultConsumer.listTopics(
          Pattern.compile(STREAMLISTTOPIC + ":" + singleCharPattern)), streamPrefix, 10);
    } finally {
      noDefaultConsumer.close();
    }

    // Consumer WITH a default stream: calls without a stream name resolve against it.
    KafkaConsumer<byte[], byte[]> defaultConsumer =
        new KafkaConsumer<byte[], byte[]>(consumerProps(DEFAULTSTREAMLISTTOPIC));
    try {
      assertTopicListing(defaultConsumer.listTopics(), defaultPrefix, numTopics);
      assertTopicListing(defaultConsumer.listTopics((String) null), defaultPrefix, numTopics);
      assertTopicListing(defaultConsumer.listTopics(DEFAULTSTREAMLISTTOPIC), defaultPrefix, numTopics);
      assertTopicListing(defaultConsumer.listTopics(STREAMLISTTOPIC), streamPrefix, numTopics);

      assertEquals(0, defaultConsumer.listTopics(Pattern.compile(noMatchPattern)).size());
      assertEquals(0, defaultConsumer.listTopics(
          Pattern.compile(STREAMLISTTOPIC + ":" + noMatchPattern)).size());
      assertTopicListing(defaultConsumer.listTopics(Pattern.compile(singleCharPattern)), defaultPrefix, 10);
      assertTopicListing(defaultConsumer.listTopics(
          Pattern.compile(STREAMLISTTOPIC + ":" + singleCharPattern)), streamPrefix, 10);
    } finally {
      defaultConsumer.close();
    }
  }

  /**
   * Asserts that {@code listing} holds the expected topics with the expected partition counts.
   *
   * <p>The listing must contain exactly {@code expectedCount} topics, each named {@code prefix + N}
   * for a distinct integer N. The topic with suffix N must have N+1 partitions, matching how
   * {@link #testListTopicsForConsumer()} creates them.
   */
  private static void assertTopicListing(Map<String, List<PartitionInfo>> listing,
                                         String prefix, int expectedCount) {
    assertEquals(expectedCount, listing.size());
    Set<Integer> seenTopics = new HashSet<Integer>();
    for (Map.Entry<String, List<PartitionInfo>> entry : listing.entrySet()) {
      String name = entry.getKey();
      assertTrue("unexpected topic " + name + " (expected prefix " + prefix + ")",
                 name.startsWith(prefix));
      int number = Integer.parseInt(name.substring(prefix.length()));
      seenTopics.add(number);
      assertEquals("partition count of " + name, number + 1, entry.getValue().size());
    }
    assertEquals(expectedCount, seenTopics.size());
  }

  /**
   * Rebalance listener that lets the test thread block until the consumer reports a partition
   * assignment or revocation.
   *
   * <p>{@link #assignDone()} and {@link #revokeDone()} reset their flag after waking, so each
   * call waits for a fresh notification.
   */
  public final class RebalanceCb implements ConsumerRebalanceListener {
    private boolean revoked;
    private boolean assigned;

    public RebalanceCb() {
      revoked = false;
      assigned = false;
    }

    /** Forgets any pending assignment/revocation notification. */
    public synchronized void clear() {
      revoked = false;
      assigned = false;
    }

    /** Blocks until {@link #onPartitionsRevoked} has been called, then resets the flag. */
    public synchronized void revokeDone() {
      while (!revoked) {
        awaitNotification("revocation");
      }
      revoked = false;
    }

    /** Blocks until {@link #onPartitionsAssigned} has been called, then resets the flag. */
    public synchronized void assignDone() {
      while (!assigned) {
        awaitNotification("assignment");
      }
      assigned = false;
    }

    /**
     * Waits on this object's monitor; the caller must hold it.
     *
     * <p>Interruption restores the thread's interrupt flag and aborts the wait with an exception,
     * rather than being silently discarded.
     */
    private void awaitNotification(String what) {
      try {
        wait();
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new IllegalStateException("Interrupted while waiting for partition " + what, e);
      }
    }

    @Override
    public synchronized void onPartitionsAssigned(Collection<TopicPartition> partitions) {
      assigned = true;
      notifyAll();
    }

    @Override
    public synchronized void onPartitionsRevoked(Collection<TopicPartition> partitions) {
      revoked = true;
      notifyAll();
    }
  }

}
