package com.mapr.baseutils.threadpool;

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.ThreadFactory;

import org.apache.log4j.Logger;

/**
 * A {@link ThreadPoolExecutor} whose thread count is adjusted at runtime by
 * {@link #healthCheck()}.
 *
 * <p>Sizing is driven by a table of {@code ThreadPoolGrowth} entries. The entry at
 * {@code curGrowthIndex} supplies the increment/decrement step and the wait
 * thresholds used for the current pool size. The pool grows when the task at the
 * head of the work queue has waited at least the entry's increment wait time. It
 * shrinks, never below the initial size, once the queue has stayed empty for at
 * least the entry's decrement wait time. Core and maximum pool sizes are always
 * set to the same value.
 *
 * <p>NOTE(review): growth tables appear to be expected in ascending
 * {@code getCurSize()} order (the index scans rely on it) — TODO confirm with the
 * code that builds them.
 */
public class GrowingThreadPool extends ThreadPoolExecutor implements HealthCheck 
{
  private ThreadPoolGrowth[] growth;
  private final String poolName;

  /*
   * Pending replacement for the growth table. changeGrowthRate() stores it and the
   * next healthCheck() validates and applies it. Guarded by growthLock so that a
   * replacement submitted while healthCheck() is consuming the previous one is
   * neither lost nor applied without validation.
   */
  private ThreadPoolGrowth[] tmpGrowth;
  private final Object growthLock = new Object();

  /* initial size of the thread pool. The size cannot be reduced below this value */
  private final int minPoolSize;
  /* current index in the growth array used to define how the thread pool grows */
  private int curGrowthIndex;
  /* current thread pool size; volatile because canProcessRequest() reads it outside healthCheck() */
  private volatile int curPoolSize;
  /* number of queued requests tolerated once every thread is busy at max size */
  private volatile int backLog;

  /*
   * the last time at which the thread pool was busy or the last time at which
   * the thread pool size was reduced
   */
  private long busySince;

  /* max size of the thread pool; volatile because it is set and read from different methods */
  private volatile int maxPoolSize;
  /* last time the "pool is at max size" debug line was printed (rate limited) */
  private long lastMsgPrintTime;
  private static final Logger LOG = Logger.getLogger(GrowingThreadPool.class);

  /**
   * Creates a pool that starts with {@code numThreads} threads and grows or
   * shrinks according to {@code growth}.
   *
   * @param name       pool name, used as the thread-name prefix and in log messages
   * @param numThreads initial (and minimum) number of threads
   * @param growth     growth table consulted by {@link #healthCheck()}
   */
  public GrowingThreadPool(String name, int numThreads, ThreadPoolGrowth[] growth) {
    super(numThreads /* corePoolSize */, 
        numThreads /* maximumPoolSize */, 
        Long.MAX_VALUE /* keepAliveTime */,
        TimeUnit.SECONDS,
        new LinkedBlockingQueue<Runnable>(),
        new MyThreadFactory(name) /* used to name threads for debugging */
        );

    this.poolName = name;
    this.growth = growth;
    this.tmpGrowth = null;
    this.backLog = 5;

    this.minPoolSize = getCorePoolSize();
    this.curPoolSize = minPoolSize;
    /* start at the growth entry that matches the initial pool size, not blindly at 0 */
    this.curGrowthIndex = (growth == null) ? 0 : advanceGrowthIndex(growth, 0, minPoolSize);
    this.busySince = System.currentTimeMillis();
    this.maxPoolSize = Integer.MAX_VALUE;
    this.lastMsgPrintTime = 0;
  }

  /**
   * Thread factory that names threads {@code <poolName>-<n>}, with n starting at 1.
   */
  static class MyThreadFactory implements ThreadFactory
  {
    int thrCount;
    String thrPrefix;

    MyThreadFactory(String poolName) {
      thrCount = 0;
      thrPrefix = poolName + "-";
    }

    public Thread newThread(Runnable r) {
      int thrId = 0;
      synchronized (this) {
        thrCount++;
        thrId = thrCount;
      }
      try {
        return new Thread(r, thrPrefix + thrId);
      }
      catch (OutOfMemoryError oom) {
        /*
         * exit, if we are unable to create a thread.
         * NOTE(review): "unable to create native thread" errors are usually thrown
         * by Thread.start() (invoked later by ThreadPoolExecutor), not by the
         * Thread constructor, so this handler may rarely fire — confirm.
         */
        LOG.fatal("Unable to create new thread, exiting the process");
        System.exit(1);
        return null;
      }
    }
  }

  /**
   * Stores a replacement growth table. It is validated and applied on the next
   * {@link #healthCheck()}. Passing {@code null} discards any replacement that
   * has not been applied yet.
   *
   * @param newGrowth the new growth table
   */
  public void changeGrowthRate(ThreadPoolGrowth[] newGrowth) {
    synchronized (growthLock) {
      tmpGrowth = newGrowth;
    }
  }

  /**
   * Sets the size the pool may grow to. This does not resize the pool by itself.
   *
   * @param maxSize new upper bound on the pool size
   */
  public void setMaxPoolSize(int maxSize) {
    maxPoolSize = maxSize;
  }

  /**
   * Sets how many queued requests {@link #canProcessRequest()} tolerates once the
   * pool is at its maximum size and no thread is idle.
   *
   * @param val backlog limit
   */
  public void setBacklog(int val) {
    this.backLog = val;
  }

  /**
   * Returns {@code maxPoolSize - activeCount + (backLog - queuedTasks)}. The
   * result is negative when the queue holds more tasks than the backlog allows.
   * It is computed in {@code long} and saturated to the {@code int} range, because
   * {@code maxPoolSize} defaults to {@link Integer#MAX_VALUE} and would otherwise
   * overflow.
   *
   * @return estimated number of additional requests the pool can accept
   */
  public int getNumOfFreeSlots() {
    long free = (long) maxPoolSize - getActiveCount()
        + ((long) backLog - getQueue().size());
    if (free > Integer.MAX_VALUE) {
      return Integer.MAX_VALUE;
    }
    if (free < Integer.MIN_VALUE) {
      return Integer.MIN_VALUE;
    }
    return (int) free;
  }

  /**
   * Returns whether a new request should be admitted. Admission requires one of
   * the following: the pool may still grow, a thread appears idle, or the queue
   * holds fewer than {@code backLog} tasks.
   *
   * @return {@code true} if the request can be accepted
   */
  public boolean canProcessRequest() {
    int poolSize = curPoolSize;
    /* check if we have reached the limit */
    if (poolSize < maxPoolSize) {
      return true;
    }
    /* check if there are idle threads available */
    if (getActiveCount() < poolSize) {
      return true;
    }
    /* no idle threads available. allow up to backLog (default=5) waiters in the queue */
    return (getQueue().size() < backLog);
  }

  /*********************************************/
  /** Implementation of HealthCheck Interface **/
  /*********************************************/
  
  /**
   * Applies any pending growth-table change, then resizes the pool. The pool
   * grows if the queued head task has waited long enough. It shrinks if the
   * queue has been empty long enough.
   */
  @Override
  public void healthCheck() {
    modifyGrowthRate();
    Runnable head = getQueue().peek();
    long curTime = System.currentTimeMillis();
    if (head != null) {
      /* work is queued, so the pool counts as busy */
      busySince = curTime;
      /*
       * Only timestamped tasks reveal how long the head has waited. Other
       * runnables, such as the FutureTask wrappers that submit() enqueues, would
       * otherwise cause a ClassCastException here.
       */
      if (head instanceof TimeStampedRunnableTask) {
        long waitTime = curTime - ((TimeStampedRunnableTask) head).arrTime();
        if (needsIncrease(waitTime)) {
          increaseThreadPoolSize();
        }
      }
    }
    else {
      /* task queue is empty */
      if (needsDecrease(curTime - busySince)) {
        busySince = curTime;
        decreaseThreadPoolSize();
      }
    }
  }

  /**
   * Scans forward from {@code from} while the next entry's {@code getCurSize()}
   * is at or below {@code poolSize}. The scan stops at the last entry instead of
   * running off the table.
   */
  private static int advanceGrowthIndex(ThreadPoolGrowth[] table, int from, int poolSize) {
    int idx = from;
    while (idx + 1 < table.length && poolSize >= table[idx + 1].getCurSize()) {
      idx++;
    }
    return idx;
  }

  /**
   * Takes the pending growth table, if any, and installs it when it is valid.
   * A valid table has at least two entries and an entry boundary above the
   * current pool size. Otherwise the table is discarded with a warning.
   */
  private void modifyGrowthRate() {
    ThreadPoolGrowth[] newGrowth;
    /* take a single snapshot so validation and installation see the same array */
    synchronized (growthLock) {
      newGrowth = tmpGrowth;
      tmpGrowth = null;
    }
    if (newGrowth == null) {
      return;
    }
    if (newGrowth.length < 2) {
      LOG.warn("modifyGrowthRate: ignoring the new growth rate for "
          + poolName + " since new growth rate array is too small");
      return;
    }

    /* pick the current index. curGrowthIndex is the only thing that changes */
    int i = advanceGrowthIndex(newGrowth, 0, curPoolSize);
    if (i == newGrowth.length - 1) {
      /* no entry boundary lies above the current size: invalid, change nothing */
      LOG.warn("modifyGrowthRate: ignoring the new growth rate for "
          + poolName + " since new growth rate is invalid");
      return;
    }
    growth = newGrowth;
    curGrowthIndex = i;
  }

  /**
   * Returns whether the pool should grow, given how long the queued head task has
   * waited. At maximum size it returns {@code false} and logs at debug level, at
   * most once every 15 seconds.
   */
  private boolean needsIncrease(long waitTime)
  {
    if (curPoolSize >= maxPoolSize) {
      long curTime = System.currentTimeMillis();
      if (curTime - lastMsgPrintTime > 15 * 1000) {
        lastMsgPrintTime = curTime;
        if (LOG.isDebugEnabled()) {
          LOG.debug("PoolName: " + poolName + " QueueSize: " + getQueue().size()
              + " WaitTime: " + waitTime + "ms");
        }
      }
      return false;
    }
    return (waitTime >= growth[curGrowthIndex].getIncrementWaitTime());
  }

  /**
   * Grows the pool by the current entry's increment, capped at
   * {@code maxPoolSize}, and advances the growth index to match the new size.
   */
  private void increaseThreadPoolSize()
  {
    int limit = maxPoolSize;
    if (curPoolSize >= limit) {
      return;
    }
    int incrementBy = growth[curGrowthIndex].getIncrement();
    int waitTime = growth[curGrowthIndex].getIncrementWaitTime();
    /* add in long and cap at the limit so a large increment cannot overflow */
    int newPoolSize = (int) Math.min((long) curPoolSize + incrementBy, (long) limit);
    /* figure out the index in the growth table, without running past its end */
    curGrowthIndex = advanceGrowthIndex(growth, curGrowthIndex, newPoolSize);
    curPoolSize = newPoolSize;
    if (LOG.isInfoEnabled()) {
      LOG.info("Increasing pool size for " + poolName + " to " + curPoolSize 
          + ", after waiting " + waitTime + "ms");
    }
    /* raise max pool size first so the core size never exceeds the maximum */
    super.setMaximumPoolSize(newPoolSize);
    super.setCorePoolSize(newPoolSize);
  }

  /**
   * Returns whether the pool should shrink after being idle for {@code waitTime}
   * ms. The pool never shrinks below {@code minPoolSize}.
   */
  private boolean needsDecrease(long waitTime) {
    return ((curPoolSize > minPoolSize) &&
        (waitTime >= growth[curGrowthIndex].getDecrementWaitTime()));
  }

  /**
   * Shrinks the pool by the current entry's decrement, floored at
   * {@code minPoolSize}, and steps the growth index back to match. If the new
   * size falls below the first entry's size, it logs a warning and leaves the
   * pool unchanged.
   */
  private void decreaseThreadPoolSize()
  {
    int decrementBy = growth[curGrowthIndex].getDecrement();
    int waitTime = growth[curGrowthIndex].getDecrementWaitTime();
    int newGrowthIndex = curGrowthIndex;
    int newPoolSize = curPoolSize - decrementBy;
    if (newPoolSize < minPoolSize) {
      newPoolSize = minPoolSize;
    }
    while (newPoolSize < growth[newGrowthIndex].getCurSize()) {
      if (newGrowthIndex == 0) {
        /* the newPoolSize < size at 0th index */
        LOG.warn("decreaseThreadPoolSize: cannot reduce pool size for "
            + poolName + " since new thread pool size is too small");
        return;
      }
      newGrowthIndex--;
    }
    curPoolSize = newPoolSize;
    curGrowthIndex = newGrowthIndex;
    if (LOG.isInfoEnabled()) {
      LOG.info("Reducing pool size for " + poolName  + " to " + curPoolSize 
          + ", after waiting " + waitTime  + "ms");
    }
    /* reduce core pool size first so it never exceeds the maximum */
    super.setCorePoolSize(newPoolSize);
    super.setMaximumPoolSize(newPoolSize);
  }
}
