/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
package com.mapr.streams.tests.producer;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.hadoop.conf.Configuration;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.Ignore;
import java.lang.Thread;
import org.junit.experimental.categories.Category;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Properties;
import com.mapr.streams.Admin;
import com.mapr.streams.StreamDescriptor;
import com.mapr.streams.Streams;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.common.PartitionInfo;
import com.mapr.fs.proto.Marlinserver.MarlinConfigDefaults;
import com.mapr.tests.BaseTest;
import com.mapr.tests.annotations.ClusterTest;

@Category(ClusterTest.class)
public class ProducerMultiTest extends BaseTest {
  private static final Logger _logger = LoggerFactory.getLogger(ProducerMultiTest.class);
  private static final String STREAM = "/jtest-" + ProducerMultiTest.class.getSimpleName();
  private static final int numPartitions = 5;
  private static final int numThreads = 20;
  private static Admin madmin;

  @BeforeClass
  public static void setupTest() throws Exception {
    final Configuration conf = new Configuration();
    madmin = Streams.newAdmin(conf);
    MarlinConfigDefaults cdef = MarlinConfigDefaults.getDefaultInstance();
    try {
      madmin.deleteStream(STREAM);
    } catch (Exception e) {
      System.out.println(e);
    }
  }

  @Before
  public void setupTable() throws Exception {
    StreamDescriptor sdesc = Streams.newStreamDescriptor();
    sdesc.setDefaultPartitions(numPartitions);
    madmin.createStream(STREAM, sdesc);
  }

  @After
  public void cleanupTest() throws Exception {
    madmin.deleteStream(STREAM);
  }

  @Test
  public void testManyProducersTogetherDeleteStream() throws Exception {
    Thread[] producers = new Thread[numThreads];
    OneProducer[] oneproducers = new OneProducer[numThreads];

    for (int i = 0; i < numThreads; ++i) {
      OneProducer producer = new OneProducer(STREAM, numPartitions, 5);
      oneproducers[i] = producer;
      producers[i] = new Thread(producer);
    }

    for (int i = 0; i < numThreads; ++i) {
      producers[i].start();
    }

    madmin.deleteStream(STREAM);
    StreamDescriptor sdesc = Streams.newStreamDescriptor();
    sdesc.setDefaultPartitions(numPartitions);
    madmin.createStream(STREAM, sdesc);

    for (int i = 0; i < numThreads; ++i) {
      producers[i].join();
    }

    boolean exceptionsHappened = false;
    for (int i = 0; i < numThreads; ++i) {
      assertTrue(oneproducers[i].success());
      exceptionsHappened = (exceptionsHappened || oneproducers[i].exceptions());
    }

    assertTrue(exceptionsHappened);

    assertTrue(!ExistZombieThreads());
  }

  @Test
  public void testManyProducersTogether() throws Exception {

    Thread[] producers = new Thread[numThreads];
    OneProducer[] oneproducers = new OneProducer[numThreads];

    for (int i = 0; i < numThreads; ++i) {
      OneProducer producer = new OneProducer(STREAM, numPartitions, 10000);
      oneproducers[i] = producer;
      producers[i] = new Thread(producer);
    }

    for (int i = 0; i < numThreads; ++i) {
      producers[i].start();
    }

    for (int i = 0; i < numThreads; ++i) {
      producers[i].join();
    }

    for (int i = 0; i < numThreads; ++i) {
      assertTrue(oneproducers[i].success());
      assertTrue(!oneproducers[i].exceptions());
    }

    assertTrue(!ExistZombieThreads());
  }

  @Test
  public void testManyProducersTogetherWithSharedObject() throws Exception {

    int numWorkers = 100;
    Thread[] producers = new Thread[numWorkers];

    MarlinConfigDefaults cdef = MarlinConfigDefaults.getDefaultInstance();
    Properties props = new Properties();
    props.put("key.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    props.put("value.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
    props.put(cdef.getMetadataMaxAge(), 100);  // want to exercise metadata refresher
    KafkaProducer kafkaproducer = new KafkaProducer<byte[], byte[]>(props);
    CountCallback cb = new CountCallback(10000 * numPartitions * numWorkers);

    for (int i = 0; i < numWorkers; ++i) {
      SendMessagesToProducer worker = new SendMessagesToProducer(kafkaproducer, cb,
                                                                 STREAM+":sharedtopic",
                                                                 numPartitions,
                                                                 10000);
      producers[i] = new Thread(worker);
    }

    for (int i = 0; i < numWorkers; ++i) {
      producers[i].start();
    }

    for (int i = 0; i < numWorkers; ++i) {
      producers[i].join();
    }

    assertTrue(cb.success());
    assertTrue(!cb.exceptions());


    assertTrue(!ExistZombieThreads());

  }

  public static boolean ExistZombieThreads() {
    boolean zombieThreads = false;
    boolean enableTrace = false;

    Map<Thread, StackTraceElement[]> stacks = Thread.getAllStackTraces();
    try {
      for(Thread t : stacks.keySet()) {
        StackTraceElement[] stes = stacks.get(t);
        boolean zombieThread = false;
        if (stes.length == 0) {
          continue;
        }

        for (StackTraceElement ste : stes) {
          if (ste.toString().toLowerCase().contains("marlin")) {
            // now make sure that the match is not only "com.mapr.streams.tests", since that is not what we want.
            if (ste.toString().replace("com.mapr.streams.tests", "ZOMBIE").toLowerCase().contains("marlin")) {
              zombieThread = true;
              break;
            }
          }
        }

        if (zombieThread || enableTrace) {
          System.out.println("Thread: " + t.getName() + " ID: " + t.getId());
          for (StackTraceElement ste : stes) {
            System.out.println(ste.toString());
          }
        }

        if (zombieThread)
          zombieThreads = true;

      }
    }
    catch (Throwable e1) {
      System.out.println("Exception while printing stacktrace " +  e1);
    }

    return zombieThreads;
  }

  public class OneProducer implements Runnable {
    private String streamName;
    private int numPartitions;
    private KafkaProducer producer;
    private int numMsgsPerPartition;
    private String topicName;
    private byte[] key;
    private byte[] value;
    private boolean success;
    private boolean exceptions;

    public OneProducer(String streamName, int numparts, int numMsgsPerPartition) {
      this.streamName = streamName;
      this.numPartitions = numparts;
      this.numMsgsPerPartition = numMsgsPerPartition;
      this.topicName = streamName+":topicname";
      this.key = new byte[20];
      this.value = new byte[20];
      this.success = false;

      MarlinConfigDefaults cdef = MarlinConfigDefaults.getDefaultInstance();
      Properties props = new Properties();
      props.put("key.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
      props.put("value.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
      props.put(cdef.getMetadataMaxAge(), 100);  // want to exercise metadata refresher
      this.producer = new KafkaProducer<byte[], byte[]>(props);
    }

    public void run() {
      CountCallback cb = new CountCallback(numMsgsPerPartition * numPartitions);
      for (int i = 0; i < numMsgsPerPartition; i++) {
        for (int j = 0; j < numPartitions; j++) {
          ProducerRecord<byte[], byte[]> record =
            new ProducerRecord<byte[], byte[]>(topicName, j, key, value);
          producer.send(record, cb);
        }
      }
      producer.flush();
      producer.close();
      success = cb.success();
      exceptions = cb.exceptions();
    }

    public boolean success() {
      return success;
    }

    public boolean exceptions() {
      return exceptions;
    }
  }

  public static final class CountCallback implements Callback {
    private int numTotalCallbacks;
    private AtomicInteger numCallbacks;
    private AtomicInteger numExceptions;
    public CountCallback(int numMsgs) {
      this.numTotalCallbacks = numMsgs;
      this.numCallbacks = new AtomicInteger(0);
      this.numExceptions = new AtomicInteger(0);
    }

    public void onCompletion(RecordMetadata metadata, Exception exception) {
      if (exception != null) {
        // System.out.println("Received exception " + exception);
        numExceptions.getAndIncrement();
      }
      int value = numCallbacks.getAndIncrement();

      if (value + 1 == numTotalCallbacks) {
        synchronized(this) {
          try {
            this.notifyAll();
          } catch (Exception e) {
            System.out.println("notify all on completion");
          }
        }
      }
    }

    public boolean exceptions() {
      return (numExceptions.get() > 0);
    }

    public void waitOnCompletion() {
      synchronized(this) {
        try {
          this.wait();
        } catch (Exception e) {
          System.out.println("waiting on completion interrupted");
        }
      }
    }

    public boolean success() {
      return (numCallbacks.get() == numTotalCallbacks);
    }
  }
}
