/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
package com.mapr.streams.tests.listener;

import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.util.*;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.regex.Pattern;

import org.apache.hadoop.conf.Configuration;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.errors.WakeupException;
import org.apache.kafka.clients.producer.KafkaProducer;

import org.apache.kafka.common.serialization.ByteArrayDeserializer;

import com.mapr.tests.BaseTest;
import com.mapr.tests.annotations.ClusterTest;

import com.mapr.streams.Admin;
import com.mapr.streams.Streams;
import com.mapr.streams.StreamDescriptor;
import com.mapr.fs.proto.Marlinserver.MarlinConfigDefaults;


/**
 * Cluster tests for consumer {@code pause}/{@code resume} and {@code wakeup}
 * against a stream created once per test class.
 *
 * <p>Every test builds its own producer and/or consumer and releases them in a
 * {@code finally} block, so a failed assertion does not leak client handles
 * into the tests that run afterwards.
 */
@Category(ClusterTest.class)
public class ListenerPauseResumeWakeupTest extends BaseTest {
  private static final Logger _logger = LoggerFactory.getLogger(ListenerPauseResumeWakeupTest.class);
  private static final String STREAM = "/jtest-" + ListenerPauseResumeWakeupTest.class.getSimpleName();
  private static Admin madmin;
  private static final int numPartitions = 10;

  private static final String BYTE_ARRAY_SERIALIZER =
      "org.apache.kafka.common.serialization.ByteArraySerializer";
  private static final String BYTE_ARRAY_DESERIALIZER =
      "org.apache.kafka.common.serialization.ByteArrayDeserializer";

  /** Charset used for both encoding and decoding record keys and values. */
  private static final String UTF8 = "UTF-8";

  /** Topic written and read by {@link #testPauseResume()}. */
  private static final String PAUSE_RESUME_TOPIC = "producertest";

  /**
   * Topic written and read by {@link #testPauseResumeSameFeed()}. It is kept
   * separate from {@link #PAUSE_RESUME_TOPIC} because the consumers start from
   * the earliest offset: sharing a topic would let records left behind by one
   * test show up in the other, making the result depend on execution order.
   */
  private static final String SAME_FEED_TOPIC = "producertest-samefeed";

  /**
   * Creates the admin handle and a fresh stream with {@code numPartitions}
   * default partitions, first removing any stream left over from an earlier run.
   *
   * @throws Exception if the admin handle or the stream cannot be created
   */
  @BeforeClass
  public static void setupTestClass() throws Exception {
    final Configuration conf = new Configuration();
    madmin = Streams.newAdmin(conf);

    // Cleanup all stale streams. Best effort: a failure here (presumably the
    // stream simply not existing yet) must not abort the setup, but it is
    // logged so that an unexpected cause remains visible.
    try {
      madmin.deleteStream(STREAM);
    } catch (Exception e) {
      _logger.debug("Ignoring failure to delete stale stream " + STREAM, e);
    }

    StreamDescriptor sdesc = Streams.newStreamDescriptor();
    sdesc.setDefaultPartitions(numPartitions);
    madmin.createStream(STREAM, sdesc);
  }

  /**
   * Deletes the stream created by {@link #setupTestClass()}.
   *
   * @throws Exception if the stream cannot be deleted
   */
  @AfterClass
  public static void cleanupTestClass() throws Exception {
    // The admin handle is still null when setup failed before assigning it;
    // skipping avoids a secondary NullPointerException during teardown.
    if (madmin != null) {
      madmin.deleteStream(STREAM);
    }
  }

  /**
   * Sends one record to every partition, checks that a poll sees all of them,
   * then sends a second round and checks that pausing the lower half of the
   * partitions restricts a poll to the upper half, and that resuming delivers
   * the lower half on the following poll.
   *
   * @throws IOException if a key or value cannot be encoded
   */
  @Test
  public void testPauseResume() throws IOException {
    final String topic = STREAM + ":" + PAUSE_RESUME_TOPIC;
    final int half = numPartitions / 2;

    KafkaProducer<byte[], byte[]> producer = newProducer();
    try {
      KafkaConsumer<byte[], byte[]> consumer =
          new KafkaConsumer<byte[], byte[]>(consumerProps(3000));
      try {
        sendOnePerPartition(producer, topic);

        List<String> topics = new ArrayList<String>(1);
        topics.add(topic);
        consumer.subscribe(topics);

        // Nothing is paused yet: every partition must be represented.
        assertPartitionsSeen(partitionRange(0, numPartitions), consumer.poll(1000));

        sendOnePerPartition(producer, topic);

        TopicPartition[] tps = new TopicPartition[half];
        for (int i = 0; i < half; ++i) {
          tps[i] = new TopicPartition(topic, i);
        }
        consumer.pause(tps);

        // Lower half paused: only the upper half may deliver records.
        assertPartitionsSeen(partitionRange(half, numPartitions), consumer.poll(1000));

        consumer.resume(tps);

        // The upper half was drained by the previous poll, so only the
        // previously paused partitions still have records to deliver.
        assertPartitionsSeen(partitionRange(0, half), consumer.poll(1000));
      } finally {
        consumer.close();
      }
    } finally {
      producer.close();
    }
  }

  /**
   * Checks that {@code wakeup()} issued from another thread while a poll is in
   * progress makes that poll end with a {@link WakeupException}.
   *
   * @throws Exception if the test thread is interrupted or the consumer fails
   */
  @Test
  public void testWakeup() throws Exception {
    runWakeupScenario(false);
  }

  /**
   * Checks that {@code wakeup()} issued before the poll starts still makes the
   * subsequent poll end with a {@link WakeupException}.
   *
   * @throws Exception if the test thread is interrupted or the consumer fails
   */
  @Test
  public void testWakeupBeforePoll() throws Exception {
    runWakeupScenario(true);
  }

  /**
   * Repeats ten times on partition 0: pause, send one record, resume, flush,
   * poll; each poll must return exactly the record that was just sent.
   *
   * @throws IOException if a key or value cannot be encoded or decoded
   */
  @Test
  public void testPauseResumeSameFeed() throws IOException {
    final String topic = STREAM + ":" + SAME_FEED_TOPIC;

    KafkaProducer<byte[], byte[]> producer = newProducer();
    try {
      Properties props = consumerProps(2000);
      props.put("fetch.min.bytes", "1");
      KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<byte[], byte[]>(props);
      try {
        List<String> topics = new ArrayList<String>(1);
        topics.add(topic);
        consumer.subscribe(topics);
        TopicPartition[] tps = new TopicPartition[] { new TopicPartition(topic, 0) };

        for (int i = 0; i < 10; i++) {
          int numMsgsConsumed = 0;
          String key = keyFor(i);

          consumer.pause(tps);
          producer.send(newRecord(topic, 0, i));
          consumer.resume(tps);
          producer.flush();

          ConsumerRecords<byte[], byte[]> recs = consumer.poll(5000);
          for (ConsumerRecord<byte[], byte[]> rec : recs) {
            String keyStr = new String(rec.key(), UTF8);
            assertTrue("round " + i + ": expected key " + key + " but got " + keyStr,
                       keyStr.equals(key));
            numMsgsConsumed++;
          }

          assertTrue("round " + i + ": expected exactly 1 record but got " + numMsgsConsumed,
                     numMsgsConsumed == 1);
          consumer.pause(tps);
        }
      } finally {
        consumer.close();
      }
    } finally {
      producer.close();
    }
  }

  /**
   * Shared body of the wakeup tests: subscribes a consumer, polls on a worker
   * thread and asserts that the poll ended with a {@link WakeupException}.
   *
   * @param wakeBeforePolling when {@code true}, {@code wakeup()} is called
   *        before the worker thread starts; otherwise it is called about one
   *        second after the worker has been started
   * @throws Exception if the test thread is interrupted or the consumer fails
   */
  private void runWakeupScenario(boolean wakeBeforePolling) throws Exception {
    KafkaConsumer<byte[], byte[]> consumer =
        new KafkaConsumer<byte[], byte[]>(consumerProps(3000));
    try {
      List<String> topics = new ArrayList<String>(1);
      topics.add(STREAM + ":RANDOMTOPIC");
      consumer.subscribe(topics);
      if (wakeBeforePolling) {
        consumer.wakeup();
      }

      PollForLongTime worker = new PollForLongTime(consumer);
      Thread workerThread = new Thread(worker);
      workerThread.start();

      if (!wakeBeforePolling) {
        // Give the worker time to get into poll() before waking it up.
        Thread.sleep(1 * 1000);
        consumer.wakeup();
      }

      workerThread.join();
      assertTrue("poll did not end with a WakeupException", worker.verify());
    } finally {
      consumer.close();
    }
  }

  /** Builds a byte-array producer with the settings shared by all tests. */
  private static KafkaProducer<byte[], byte[]> newProducer() {
    MarlinConfigDefaults cdef = MarlinConfigDefaults.getDefaultInstance();
    Properties props = new Properties();
    props.put("key.serializer", BYTE_ARRAY_SERIALIZER);
    props.put("value.serializer", BYTE_ARRAY_SERIALIZER);
    props.put(cdef.getParallelFlushersPerPartition(), "false");
    return new KafkaProducer<byte[], byte[]>(props);
  }

  /**
   * Builds the consumer properties shared by all tests: byte-array
   * deserializers and {@code auto.offset.reset=earliest}.
   *
   * @param metadataMaxAge value stored under the key returned by
   *        {@code MarlinConfigDefaults.getMetadataMaxAge()} (presumably a
   *        metadata refresh interval in milliseconds; confirm against the
   *        client configuration reference)
   * @return a new, independently mutable property set
   */
  private static Properties consumerProps(int metadataMaxAge) {
    MarlinConfigDefaults cdef = MarlinConfigDefaults.getDefaultInstance();
    Properties props = new Properties();
    props.put("key.deserializer", BYTE_ARRAY_DESERIALIZER);
    props.put("value.deserializer", BYTE_ARRAY_DESERIALIZER);
    props.put("auto.offset.reset", "earliest");
    props.put(cdef.getMetadataMaxAge(), metadataMaxAge);
    return props;
  }

  /** Returns the record key used for sequence number {@code seq}. */
  private static String keyFor(int seq) {
    return "key-value" + seq;
  }

  /**
   * Builds the record for sequence number {@code seq}, addressed to an
   * explicit partition, with key and value encoded as UTF-8.
   */
  private static ProducerRecord<byte[], byte[]> newRecord(String topic, int partition, int seq)
      throws IOException {
    String msg = "msg-value" + seq;
    return new ProducerRecord<byte[], byte[]>(topic,
                                              partition,
                                              keyFor(seq).getBytes(UTF8),
                                              msg.getBytes(UTF8));
  }

  /** Sends one record to each of the {@code numPartitions} partitions, then flushes. */
  private static void sendOnePerPartition(KafkaProducer<byte[], byte[]> producer, String topic)
      throws IOException {
    for (int p = 0; p < numPartitions; ++p) {
      producer.send(newRecord(topic, p, p));
    }
    producer.flush();
  }

  /** Returns the partition ids {@code from} (inclusive) to {@code to} (exclusive). */
  private static Set<Integer> partitionRange(int from, int to) {
    Set<Integer> ids = new HashSet<Integer>();
    for (int p = from; p < to; ++p) {
      ids.add(p);
    }
    return ids;
  }

  /**
   * Asserts that the partitions represented in {@code recs} are exactly
   * {@code expected}; the failure message lists both sets.
   */
  private static void assertPartitionsSeen(Set<Integer> expected,
                                           ConsumerRecords<byte[], byte[]> recs) {
    Set<Integer> seen = new HashSet<Integer>();
    for (ConsumerRecord<byte[], byte[]> rec : recs) {
      seen.add(rec.partition());
    }
    assertTrue("expected records from partitions " + expected + " but got " + seen,
               expected.equals(seen));
  }

  /**
   * Runnable that issues one very long poll on the given consumer and records
   * whether the poll ended with an exception and whether that exception was a
   * {@link WakeupException}.
   *
   * <p>The result flags are plain fields; callers read them through
   * {@link #verify()} after joining the thread that executed {@link #run()}.
   */
  public class PollForLongTime implements Runnable {
    // Raw type kept on purpose so the constructor signature stays unchanged.
    private final KafkaConsumer consumer;
    private boolean interrupted;
    private boolean correctTypeOfException;

    /**
     * @param c consumer to poll; it must already be subscribed
     */
    public PollForLongTime(KafkaConsumer c) {
      consumer = c;
      interrupted = false;
      correctTypeOfException = false;
    }

    /** Polls with a timeout long enough that only an exception ends the call. */
    @Override
    public void run() {
      try {
        System.out.println("Polling...");
        consumer.poll(1000000000L); // never return!
        System.out.println("Done polling...");
      } catch (Exception e) {
        System.out.println("Interrupted polling thread with wake up, " + e);
        interrupted = true;
        if (e instanceof WakeupException) {
          correctTypeOfException = true;
        }
      }
    }

    /**
     * @return {@code true} only if the poll ended with a {@link WakeupException}
     */
    public boolean verify() {
      return (interrupted && correctTypeOfException);
    }
  }


}
