/* Copyright (c) 2015 & onwards. MapR Tech, Inc., All rights reserved */
package com.mapr.streams.tests.listener;

import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.util.*;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.regex.Pattern;

import org.apache.hadoop.conf.Configuration;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.errors.WakeupException;
import org.apache.kafka.clients.producer.KafkaProducer;

import org.apache.kafka.common.serialization.ByteArrayDeserializer;

import com.mapr.tests.BaseTest;
import com.mapr.tests.annotations.ClusterTest;

import com.mapr.streams.Admin;
import com.mapr.streams.Streams;
import com.mapr.streams.StreamDescriptor;
import com.mapr.fs.proto.Marlinserver.MarlinConfigDefaults;

import com.mapr.streams.impl.listener.MarlinListener;
import com.mapr.streams.impl.listener.MarlinListener.MarlinJoinCallback;
import com.mapr.fs.proto.Marlinserver.*;
import java.util.concurrent.locks.*;
import java.util.concurrent.TimeUnit;
import com.google.protobuf.ByteString;
import org.apache.kafka.clients.consumer.*;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.mapr.GenericHFactory;

  
@Category(ClusterTest.class)
public class ListenerJoin extends BaseTest {
  private static final Logger _logger = LoggerFactory.getLogger(ListenerJoin.class);
  private static final String STREAM = "/jtest-" + ListenerJoin.class.getSimpleName();
  private static Admin madmin;
  private static final int numPartitions = 10;

  /** How long a joiner waits for a join/rejoin notification before giving up. */
  private static final long WAIT_SECONDS = 60L;
  /** Timeout, in milliseconds, of each poll of the assignment sync topic. */
  private static final long POLL_TIMEOUT_MS = 15000L;
  /** How long the early leaver stays in the group after its join call returns. */
  private static final long LEAVER_DELAY_MS = 10000L;
  /** How long testJoin waits before starting the late joiner. */
  private static final long LATE_JOINER_DELAY_MS = 5000L;

  /** Recreates {@link #STREAM} with {@link #numPartitions} default partitions. */
  @BeforeClass
  public static void setupTestClass() throws Exception {
    final Configuration conf = new Configuration();
    madmin = Streams.newAdmin(conf);

    // Best-effort cleanup of a stream left behind by an earlier run; failure is expected
    // when no such stream exists.
    try {
      madmin.deleteStream(STREAM);
    } catch (Exception ignored) {
    }

    StreamDescriptor sdesc = Streams.newStreamDescriptor();
    sdesc.setDefaultPartitions(numPartitions);
    madmin.createStream(STREAM, sdesc);
  }

  /** Deletes the stream created by {@link #setupTestClass()}. */
  @AfterClass
  public static void cleanupTestClass() throws Exception {
    madmin.deleteStream(STREAM);
  }

  /**
   * Runs five {@link Joiner}s in the same listener group. Joiners 0-2 stay in the group,
   * joiner 3 closes its listener some time after joining, and joiner 4 is started after
   * a delay. Fails if any joiner recorded a failure (join timeout or unexpected exception).
   */
  @Test
  public void testJoin() throws IOException {
    final int numJoiners = 5;
    final int leaverIndex = 3;
    final int lateIndex = 4;
    Joiner[] joiners = new Joiner[numJoiners];
    Thread[] threads = new Thread[numJoiners];

    for (int i = 0; i < lateIndex; i++) {
      joiners[i] = new Joiner(i == leaverIndex, i);
      threads[i] = new Thread(joiners[i]);
      threads[i].start();
    }
    // The late joiner is created now but started only after a delay.
    joiners[lateIndex] = new Joiner(false, lateIndex);
    threads[lateIndex] = new Thread(joiners[lateIndex]);

    try {
      Thread.sleep(LATE_JOINER_DELAY_MS);
      threads[lateIndex].start();
      for (Thread t : threads) {
        t.join();
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new AssertionError("interrupted while waiting for joiner threads", e);
    }

    for (int i = 0; i < numJoiners; i++) {
      assertTrue(joiners[i].failure, joiners[i].failure == null);
    }
  }

  /**
   * One member of the "testgroup" listener group, run on its own thread.
   *
   * <p>Each round it calls {@link MarlinListener#join}, then waits for {@link JoinTestCallback}
   * to report either a completed join or a rejoin request. After a completed join it polls
   * the sync topic until it sees the assignment for the current generation (produced by the
   * group leader inside {@code onJoin}) and then waits for a rejoin request before joining
   * again. The thread ends when no rejoin request arrives within {@link #WAIT_SECONDS}.
   */
  public class Joiner implements Runnable {
    private Properties props;
    private MarlinListener listener;
    public JoinTestCallback cb = new JoinTestCallback();
    final Lock lock = new ReentrantLock();
    final Condition condition = lock.newCondition();
    // Once the joiner thread is running, both flags are only accessed while holding lock.
    boolean joinComplete;
    boolean rejoinNeeded;

    private JoinGroupDesc desc;
    private WorkerState ws;
    ByteString bstrMetadata;
    // Written by the joiner thread and read from the listener callback, hence volatile.
    volatile JoinGroupResponse resp;
    // Set by onJoin before it signals under lock; read by the joiner after seeing the signal.
    JoinGroupInfo joinInfo;

    //sync
    KafkaProducer<Long, byte[]> producer;
    KafkaConsumer<Long, byte[]> consumer;
    final String groupId = "testgroup";
    final String syncTopic = STREAM + ":__mapr__" + groupId + "_assignment";

    boolean leaveEarly;
    private final int index;
    // Non-null once this joiner has failed; checked by testJoin after the thread ends.
    volatile String failure;

    /**
     * @param l whether to close the listener shortly after joining, to simulate a member leaving
     * @param i index identifying the joiner; used in its worker URL and in messages
     */
    public Joiner(boolean l, int i) {
      index = i;
      props = new Properties();
      props.put("group.id", groupId);
      props.put("key.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer");
      props.put("value.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer");
      props.put("key.serializer", "org.apache.kafka.common.serialization.LongSerializer");
      props.put("value.serializer", "org.apache.kafka.common.serialization.ByteArraySerializer");
      props.put("auto.offset.reset", "earliest");
      props.put("streams.consumer.default.stream", STREAM);
      GenericHFactory<ConsumerConfig> configFactory = new GenericHFactory<ConsumerConfig>();
      ConsumerConfig config = configFactory.getImplementorInstance("org.apache.kafka.clients.consumer.ConsumerConfig",
                                           new Object[] {props},
                                           new Class[] {Map.class});
      listener = new MarlinListener(config, null, null);
      joinComplete = false;
      rejoinNeeded = false;
      ws = WorkerState.newBuilder().setUrl("worker_url_" + i).setOffset(1000).build();
      bstrMetadata = ws.toByteString();
      desc = JoinGroupDesc.newBuilder().setProtocolType("connect").
                          addMemberProtocols(MemberProtocol.newBuilder().
                          setProtocol("someprotocol").
                          setMemberMetadata(bstrMetadata).build()).build();
      leaveEarly = l;
      // The same props are reused for the sync-topic consumer (no group) and the
      // leader's producer, with Long keys carrying the generation id.
      props.remove("group.id");
      props.put("key.deserializer", "org.apache.kafka.common.serialization.LongDeserializer");
      consumer = new KafkaConsumer<Long, byte[]>(props);
    }

    @Override
    public void run() {
      try {
        runRounds();
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        _logger.info("joiner {} interrupted", index);
      } catch (RuntimeException e) {
        failure = "joiner " + index + " failed: " + e;
        _logger.error("joiner {} failed", index, e);
        throw e;
      } finally {
        consumer.close();
      }
    }

    /**
     * Repeats join rounds until this joiner leaves early, its join times out (recorded
     * in {@link #failure}), or no rejoin request arrives in time.
     */
    private void runRounds() throws InterruptedException {
      while (true) {
        lock.lock();
        try {
          joinComplete = false;
          rejoinNeeded = false;
        } finally {
          lock.unlock();
        }

        resp = listener.join(desc, cb);
        _logger.info("memberid from join API is {}", resp.getMemberId());
        desc = JoinGroupDesc.newBuilder().mergeFrom(desc).
                             setMemberId(resp.getMemberId()).build();

        if (leaveEarly) {
          // Stay in the group for a while after joining, then drop out.
          Thread.sleep(LEAVER_DELAY_MS);
          _logger.info("leaver memberid is {}", resp.getMemberId());
          listener.close();
          return;
        }

        boolean joined;
        lock.lock();
        try {
          while (!joinComplete && !rejoinNeeded) {
            if (!condition.await(WAIT_SECONDS, TimeUnit.SECONDS)) {
              failure = "joiner " + index + " timed out waiting for its join to complete";
              _logger.error("joiner {}: join/rejoin time expired", index);
              return;
            }
          }
          joined = joinComplete;
        } finally {
          lock.unlock();
        }

        // If the join did not complete, a rejoin was requested: start the next round.
        if (joined) {
          awaitAssignment();
          if (!awaitRejoin()) {
            return;
          }
        }
      }
    }

    /**
     * Polls the sync topic until the last record of a batch carries the current
     * generation id, or until the listener requests a rejoin.
     */
    private void awaitAssignment() {
      for (int attempt = 0; ; attempt++) {
        consumer.subscribe(Arrays.asList(syncTopic));
        ConsumerRecords<Long, byte[]> records = consumer.poll(POLL_TIMEOUT_MS);
        _logger.info("past consumer poll{}", attempt);
        long lastSeen = 0L;
        for (ConsumerRecord<Long, byte[]> record : records) {
          Long generation = record.key();
          _logger.info("consumer record..generation ID {}", generation);
          if (generation != null) {
            lastSeen = generation;
          }
        }

        lock.lock();
        try {
          if (rejoinNeeded) {
            return;
          }
        } finally {
          lock.unlock();
        }
        if (joinInfo.getGroupGenerationId() == lastSeen) {
          return;
        }
      }
    }

    /**
     * Waits for the listener to request a rejoin.
     *
     * @return true if a rejoin was requested, false if none arrived within WAIT_SECONDS
     */
    private boolean awaitRejoin() throws InterruptedException {
      lock.lock();
      try {
        while (!rejoinNeeded) {
          if (!condition.await(WAIT_SECONDS, TimeUnit.SECONDS)) {
            _logger.info("rejoin time expired");
            return false;
          }
        }
        return true;
      } finally {
        lock.unlock();
      }
    }

    /** Listener callback: records join/rejoin events and wakes the joiner thread. */
    public class JoinTestCallback implements MarlinJoinCallback {
      @Override
      public void onJoin(JoinGroupInfo jgi) {
        // NOTE(review): assumes listener.join() has returned (so resp is set) before this
        // callback fires; a callback delivered earlier would see a null resp here.
        if (jgi.getGroupLeaderId().equalsIgnoreCase(resp.getMemberId())) {
          publishAssignment(jgi);
        }
        joinInfo = jgi;
        lock.lock();
        try {
          joinComplete = true;
          condition.signal();
        } finally {
          lock.unlock();
        }
      }

      @Override
      public void onRejoin(JoinGroupInfo jgi) {
        lock.lock();
        try {
          rejoinNeeded = true;
          condition.signal();
        } finally {
          lock.unlock();
        }
      }

      /**
       * Builds a GroupAssignment that names this member as leader for every member in
       * {@code jgi} and produces it to the sync topic keyed by the generation id.
       * Errors are logged rather than rethrown so the joiner is still signalled.
       */
      private void publishAssignment(JoinGroupInfo jgi) {
        try {
          _logger.info("leader id is {}", jgi.getGroupLeaderId());
          Map<String, WorkerState> wsMap = new HashMap<String, WorkerState>();
          for (int m = 0; m < jgi.getMembersCount(); m++) {
            wsMap.put(jgi.getMembers(m).getMemberId(),
                      WorkerState.newBuilder().
                                  mergeFrom(jgi.getMembers(m).getMemberMetadata()).build());
          }

          GroupAssignment.Builder gaBuilder = GroupAssignment.newBuilder().
                                                setGroupGenerationId(jgi.getGroupGenerationId());
          String leaderUrl = ws.getUrl();
          for (Map.Entry<String, WorkerState> e : wsMap.entrySet()) {
            WorkerAssignment wa = WorkerAssignment.newBuilder().setLeaderURL(leaderUrl).
                                                   setLeaderId(jgi.getGroupLeaderId()).build();
            MemberState ms = MemberState.newBuilder().setMemberId(e.getKey()).
                                         setMemberAssignment(wa.toByteString()).build();
            gaBuilder.addMemberState(ms);
            _logger.info("Member {} url is {} offset is {}",
                         e.getKey(), e.getValue().getUrl(), e.getValue().getOffset());
          }
          GroupAssignment ga = gaBuilder.build();

          _logger.info("producing generation id {}", jgi.getGroupGenerationId());
          producer = new KafkaProducer<Long, byte[]>(props);
          try {
            producer.send(new ProducerRecord<Long, byte[]>(syncTopic, jgi.getGroupGenerationId(),
                                                           ga.toByteArray()));
          } finally {
            producer.close();
          }
        } catch (Exception e) {
          _logger.error("failed to publish group assignment for generation {}",
                        jgi.getGroupGenerationId(), e);
        }
      }
    }
  }
}
