/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
package com.mapr.streams.tests;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import java.util.regex.Pattern;

import org.apache.hadoop.conf.Configuration;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.ojai.Document;
import org.ojai.DocumentStream;
import org.ojai.FieldPath;
import org.ojai.store.DocumentStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.mapr.db.exceptions.TableNotFoundException;
import com.mapr.streams.Admin;
import com.mapr.streams.StreamDescriptor;
import com.mapr.streams.Streams;
import com.mapr.streams.impl.MessageStore;
import com.mapr.tests.BaseTest;
import com.mapr.tests.annotations.ClusterTest;

@Category(ClusterTest.class)
public class BasicAnalyticsTest extends BaseTest {
  private static final Logger _logger = LoggerFactory.getLogger(BasicAnalyticsTest.class);
  private static final String PREFIX = "/jtest-" + BasicAnalyticsTest.class.getSimpleName() + "-";
  private static final String STREAM = PREFIX + "smallStream";
  private static Admin madmin;
  private static final int FEEDS = 26;

  /** Number of topics created in the stream: 'topicA' through 'topicZ'. */
  private static final int NUM_TOPICS = 'Z' - 'A' + 1;

  /** Number of produce batches written during setup. */
  private static final int NUM_BATCHES = 50;

  /** Exclusive upper bound on the number of messages in one batch (0 to 49). */
  private static final int MAX_MSGS_PER_BATCH = 50;

  /** Records how many messages one setup batch produced to which topic/partition. */
  static class TopicFeedMsg {
    public int topicId;
    public int feed;
    public int numMsgs;

    TopicFeedMsg(int t, int f, int m) {
      topicId = t;
      feed = f;
      numMsgs = m;
    }
  }

  // This list allows us to track how many messages have been generated
  // for which topic. This is for sanity checking
  private static List<TopicFeedMsg> allMsgsData;

  private static int totalMsgs;
  private static int totalExpectedSplits;

  // Set by ProducerCallback when an asynchronous send fails; checked by setupTest().
  // volatile because callbacks may run on a producer-owned thread.
  private static volatile Exception sendError;

  /** Returns the name of the topic with the given id: 0 is "topicA", 25 is "topicZ". */
  private static String topicName(int topicId) {
    return "topic" + (char) ('A' + topicId);
  }

  /** Returns the partition count the topic with the given id is created with (topicA: 1). */
  private static int numPartitions(int topicId) {
    return topicId + 1;
  }

  /**
   * Creates the test stream with 26 topics and produces a random set of message batches into it,
   * recording per-batch counts in {@link #allMsgsData}.
   *
   * <p>Pass {@code -Dseed=<n>} to reproduce a specific data set; the seed actually used is logged.
   */
  @BeforeClass
  public static void setupTest() throws Exception {
    final Configuration conf = new Configuration();
    madmin = Streams.newAdmin(conf);

    String seedArg = System.getProperty("seed");
    long seed = (seedArg != null) ? Long.parseLong(seedArg) : System.nanoTime();
    _logger.info("Generating test data with seed {}", seed);

    try {
      // Cleanup stale stream left by a previous run
      madmin.deleteStream(STREAM);
    } catch (TableNotFoundException ignored) {
      // Nothing to clean up
    }

    // Create new stream
    StreamDescriptor sdesc = Streams.newStreamDescriptor();
    sdesc.setDefaultPartitions(FEEDS);
    madmin.createStream(STREAM, sdesc);

    // Create 26 topics 'topicA' (1 partition) through 'topicZ' (26 partitions)
    totalExpectedSplits = 0;
    for (int topicId = 0; topicId < NUM_TOPICS; topicId++) {
      int partitions = numPartitions(topicId);
      madmin.createTopic(STREAM, topicName(topicId), partitions);
      totalExpectedSplits += partitions;
    }

    Properties props = new Properties();
    props.put("key.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    props.put("value.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    props.put("streams.parallel.flushers.per.partition", false);  // Do not allow multiple flushers

    sendError = null;
    allMsgsData = new ArrayList<TopicFeedMsg>();
    totalMsgs = 0;

    Random rng = new Random(seed);
    KafkaProducer<byte[], byte[]> producer = new KafkaProducer<byte[], byte[]>(props);
    try {
      Callback cb = new ProducerCallback();
      for (int i = 0; i < NUM_BATCHES; i++) {
        int topicId = rng.nextInt(NUM_TOPICS);
        // Any partition of the topic, including its last one
        int feed = rng.nextInt(numPartitions(topicId));
        int numMsgs = rng.nextInt(MAX_MSGS_PER_BATCH);
        String fullTopicName = STREAM + ":" + topicName(topicId);

        for (int j = 0; j < numMsgs; j++) {
          byte[] key = new byte[10];
          byte[] val = new byte[10];
          producer.send(new ProducerRecord<byte[], byte[]>(fullTopicName, feed, key, val), cb);
        }

        producer.flush();
        throwIfSendFailed();
        allMsgsData.add(new TopicFeedMsg(topicId, feed, numMsgs));
        totalMsgs += numMsgs;
      }
    } finally {
      producer.close();
    }
    // close() waits for outstanding sends, so check once more for late failures
    throwIfSendFailed();
  }

  /** Fails setup if any asynchronous send reported an error. */
  private static void throwIfSendFailed() {
    Exception failure = sendError;
    if (failure != null) {
      throw new IllegalStateException("Failed to produce test messages to " + STREAM, failure);
    }
  }

  /** Iterates the stream to its end, closes it, and returns the number of documents seen. */
  private int countMessages(DocumentStream rs) throws Exception {
    try {
      int count = 0;
      Iterator<Document> iter = rs.iterator();
      while (iter.hasNext()) {
        iter.next();
        count++;
      }
      return count;
    } finally {
      rs.close();
    }
  }

  /** Counts every document returned by an unprojected {@code find()} on the store. */
  private int countMessages(DocumentStore store) throws Exception {
    return countMessages(store.find());
  }

  @Test
  public void testAllMsgs() throws Exception {
    DocumentStore store = Streams.getMessageStore(STREAM);
    assertEquals(totalMsgs, countMessages(store));
    assertEquals(totalExpectedSplits, ((MessageStore) store).getNumSplits());
  }

  /** Returns the total number of messages setup produced to the given topic. */
  private int getNumExpectedMsgs(int topicId) {
    int expectedMsgs = 0;
    for (TopicFeedMsg tfm : allMsgsData) {
      if (tfm.topicId == topicId) {
        expectedMsgs += tfm.numMsgs;
      }
    }
    return expectedMsgs;
  }

  /** A set of distinct topic names plus the totals a store over them should report. */
  private static final class TopicSelection {
    final List<String> names = new ArrayList<String>();
    int expectedMsgs;
    int expectedSplits;
  }

  /** Draws {@code draws} random topic ids (duplicates collapse) and sums their expectations. */
  private TopicSelection selectRandomTopics(Random rng, int draws) {
    Set<Integer> topicIds = new HashSet<Integer>();
    for (int i = 0; i < draws; i++) {
      topicIds.add(rng.nextInt(NUM_TOPICS));
    }

    TopicSelection selection = new TopicSelection();
    for (int topicId : topicIds) {
      selection.names.add(topicName(topicId));
      selection.expectedMsgs += getNumExpectedMsgs(topicId);
      selection.expectedSplits += numPartitions(topicId);
    }
    return selection;
  }

  @Test
  public void testSingleTopicMsgs() throws Exception {
    Random rng = new Random();

    // Draw 10 random topics (repeats possible) and verify that each one returns
    // exactly the messages produced to it, with one split per partition
    for (int i = 0; i < 10; i++) {
      int topicId = rng.nextInt(NUM_TOPICS);
      DocumentStore store = Streams.getMessageStore(STREAM, topicName(topicId));
      assertEquals(getNumExpectedMsgs(topicId), countMessages(store));
      assertEquals(numPartitions(topicId), ((MessageStore) store).getNumSplits());
    }
  }

  @Test
  public void testMultiTopicMsgs() throws Exception {
    // Draw 15 random topics (duplicates collapse) and verify the combined
    // message and split counts when the store is opened on all of them at once
    TopicSelection selection = selectRandomTopics(new Random(), 15);

    String[] topicArr = selection.names.toArray(new String[selection.names.size()]);
    DocumentStore store = Streams.getMessageStore(STREAM, topicArr);
    assertEquals(selection.expectedMsgs, countMessages(store));
    assertEquals(selection.expectedSplits, ((MessageStore) store).getNumSplits());
  }

  @Test
  public void testRegexTopicMsgs() throws Exception {
    // Same as testMultiTopicMsgs, but selects the topics with an alternation
    // regex of the form "(topicA|topicK|...)"
    TopicSelection selection = selectRandomTopics(new Random(), 15);

    StringBuilder regex = new StringBuilder("(");
    for (String topic : selection.names) {
      if (regex.length() > 1) {
        regex.append('|');
      }
      regex.append(topic);
    }
    regex.append(')');

    DocumentStore store = Streams.getMessageStore(STREAM, Pattern.compile(regex.toString()));
    assertEquals(selection.expectedMsgs, countMessages(store));
    assertEquals(selection.expectedSplits, ((MessageStore) store).getNumSplits());
  }

  @Test
  public void testNonExistantTopic() throws Exception {
    DocumentStore store = Streams.getMessageStore(STREAM, "topicNonExistant");
    assertEquals(0, countMessages(store));
  }

  @Test(expected = IllegalArgumentException.class)
  public void testNullTopic() throws Exception {
    Streams.getMessageStore(STREAM, (String) null);
  }

  @Test
  public void testProjectionTopicKey() throws Exception {
    DocumentStore store = Streams.getMessageStore(STREAM);
    DocumentStream rs = store.find(Streams.TOPIC, Streams.KEY);

    int count = 0;
    try {
      Iterator<Document> iter = rs.iterator();
      while (iter.hasNext()) {
        Document doc = iter.next();

        // Projected fields are present
        assertNotNull(doc.getString(Streams.TOPIC));
        assertNotNull(doc.getBinary(Streams.KEY));

        // Everything else is absent
        assertNull(doc.getBinary(Streams.VALUE));
        assertNull(doc.getString(Streams.PRODUCER));
        try {
          doc.getInt(Streams.PARTITION);
          fail("partition should not be projected");
        } catch (NoSuchElementException expected) {
          // expected: field not projected
        }
        try {
          doc.getLong(Streams.OFFSET);
          fail("offset should not be projected");
        } catch (NoSuchElementException expected) {
          // expected: field not projected
        }
        count++;
      }
    } finally {
      rs.close();
    }

    assertEquals(totalMsgs, count);
  }

  @Test
  public void testProjectionTopicPartition() throws Exception {
    DocumentStore store = Streams.getMessageStore(STREAM);
    DocumentStream rs = store.find(Streams.TOPIC, Streams.PARTITION);

    int count = 0;
    try {
      Iterator<Document> iter = rs.iterator();
      while (iter.hasNext()) {
        Document doc = iter.next();

        // Projected fields are present
        assertNotNull(doc.getString(Streams.TOPIC));
        try {
          doc.getInt(Streams.PARTITION);
        } catch (NoSuchElementException e) {
          fail("partition should be projected");
        }

        // Everything else is absent
        assertNull(doc.getBinary(Streams.KEY));
        assertNull(doc.getBinary(Streams.VALUE));
        assertNull(doc.getString(Streams.PRODUCER));
        try {
          doc.getLong(Streams.OFFSET);
          fail("offset should not be projected");
        } catch (NoSuchElementException expected) {
          // expected: field not projected
        }
        count++;
      }
    } finally {
      rs.close();
    }

    assertEquals(totalMsgs, count);
  }

  @Test
  public void testProjectionNull() throws Exception {
    DocumentStore store = Streams.getMessageStore(STREAM);

    try {
      store.find((FieldPath) null);
      fail("find(null projection) should throw IllegalArgumentException");
    } catch (IllegalArgumentException expected) {
      // expected
    }
  }

  @Test
  public void testProjectionKeyValueProducer() throws Exception {
    DocumentStore store = Streams.getMessageStore(STREAM);
    DocumentStream rs = store.find(Streams.PRODUCER, Streams.KEY, Streams.VALUE);

    int count = 0;
    try {
      Iterator<Document> iter = rs.iterator();
      while (iter.hasNext()) {
        Document doc = iter.next();

        // Projected fields are present
        assertNotNull(doc.getString(Streams.PRODUCER));
        assertNotNull(doc.getBinary(Streams.KEY));
        assertNotNull(doc.getBinary(Streams.VALUE));

        // Everything else is absent
        assertNull(doc.getString(Streams.TOPIC));
        try {
          doc.getInt(Streams.PARTITION);
          fail("partition should not be projected");
        } catch (NoSuchElementException expected) {
          // expected: field not projected
        }
        try {
          doc.getLong(Streams.OFFSET);
          fail("offset should not be projected");
        } catch (NoSuchElementException expected) {
          // expected: field not projected
        }
        count++;
      }
    } finally {
      rs.close();
    }

    assertEquals(totalMsgs, count);
  }

/*
  @Test
  public void testInvalidProjections() throws Exception {
    DocumentStore store = Streams.getMessageStore(STREAM);

    boolean exceptionCaught = false;

    try {
      exceptionCaught = false;
      DocumentStream<Document>rs = (DocumentStream<Document>)store.find(Streams.TIMESTAMP);
    } catch (UnsupportedOperationException e) {
      exceptionCaught = true;
    }

    assertTrue(exceptionCaught);
  }
*/

  @Test
  public void testProjectionProducer() throws Exception {
    DocumentStore store = Streams.getMessageStore(STREAM);
    DocumentStream rs = store.find(Streams.PRODUCER);

    int count = 0;
    try {
      Iterator<Document> iter = rs.iterator();
      while (iter.hasNext()) {
        Document doc = iter.next();

        // Projected field is present
        assertNotNull(doc.getString(Streams.PRODUCER));

        // Everything else is absent
        assertNull(doc.getBinary(Streams.KEY));
        assertNull(doc.getBinary(Streams.VALUE));
        assertNull(doc.getString(Streams.TOPIC));
        try {
          doc.getInt(Streams.PARTITION);
          fail("partition should not be projected");
        } catch (NoSuchElementException expected) {
          // expected: field not projected
        }
        try {
          doc.getLong(Streams.OFFSET);
          fail("offset should not be projected");
        } catch (NoSuchElementException expected) {
          // expected: field not projected
        }
        count++;
      }
    } finally {
      rs.close();
    }

    assertEquals(totalMsgs, count);
  }

  /**
   * Logs and records send failures so that setupTest() can fail the test class normally,
   * instead of killing the whole JVM.
   */
  private static final class ProducerCallback implements Callback {

    public ProducerCallback() {
    }

    @Override
    public void onCompletion(RecordMetadata metadata, Exception exception) {
      if (exception != null) {
        _logger.error("Failed to send test message", exception);
        if (sendError == null) {
          sendError = exception;
        }
      }
    }
  }
}
