package com.mapr.db.mapreduce.impl;

import static org.ojai.DocumentConstants.ID_FIELD;
import static org.ojai.store.QueryCondition.Op.GREATER_OR_EQUAL;
import static org.ojai.store.QueryCondition.Op.LESS;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import java.util.TreeSet;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.ojai.Value;
import org.ojai.Value.Type;
import org.ojai.store.QueryCondition;

import com.mapr.db.MapRDB;
import com.mapr.db.TabletInfo;
import com.mapr.db.impl.ConditionNode.RowkeyRange;
import com.mapr.db.impl.IdCodec;
import com.mapr.db.mapreduce.TableInputFormat;
import com.mapr.org.apache.hadoop.hbase.util.Bytes;

public class RangeChecksumInputFormat extends TableInputFormat {

  public final static String SPLITFILENAME = "splitfilename";
  public final static String INCLUDEDREGIONFILENAME = "includedregionfilename";

  private static final Log LOG = LogFactory.getLog(RangeChecksumInputFormat.class);
  private String splitFileName = null;
  private String includedRegionStartPointsFileName = null;
  private List<ByteBufWritableComparable> splitPoints = null;
  private List<ByteBufWritableComparable> includedRegionStartPoints = null;
  private TreeSet<ByteBufWritableComparable> searchableSplitPoints = null;
  private TreeSet<ByteBufWritableComparable> searchableIncludedRegionStartPoints = null;

  @Override
  public void setConf(Configuration configuration) {
    splitFileName = configuration.get(SPLITFILENAME, null);
    if (splitFileName != null) {
    try {
        Path splitFilePath = new Path(splitFileName);
        splitPoints = DiffTableUtils.readKeyRange(configuration, splitFilePath);
      } catch (IOException ie) {
        throw new IllegalArgumentException("can't read  file " + splitFileName + " ", ie);
      }
    }

    includedRegionStartPointsFileName = configuration.get(INCLUDEDREGIONFILENAME, null);
    if (includedRegionStartPointsFileName != null) {
      try {
        Path includedRegionStartPointsFilePath = new Path(includedRegionStartPointsFileName);
        includedRegionStartPoints = DiffTableUtils.readKeyRange(configuration, includedRegionStartPointsFilePath);
      } catch (IOException ie) {
        throw new IllegalArgumentException("can't read  file " + includedRegionStartPointsFileName + " ", ie);
      }
    }
    super.setConf(configuration);
  }

  // Given a row key, find which split this key belongs to, and return the start
  // row key of that split.
  public ByteBufWritableComparable getSplitStartKey(ByteBufWritableComparable key) {
    if (splitPoints == null) {
      throw new IllegalArgumentException("splitPoints are null from file " + splitFileName);
    }
    if (searchableSplitPoints == null) {
      searchableSplitPoints = new TreeSet<ByteBufWritableComparable>(splitPoints);
    }
    ByteBufWritableComparable ret = searchableSplitPoints.floor(key);
    if (ret == null) {
      ret = splitPoints.get(0);
      LOG.warn("key " +  key.toString() + " does not belong to a range. " +
                "Return first range start key " + ret.toString());
    }
    return ret;
  }

  // Whether this split should be included in the mapreduce job
  protected boolean includeRegionInSplit(final byte[] startKey, final byte [] endKey) {
    // No specific start points given as the ranges that need to be included, include every range.
    if (includedRegionStartPoints == null) {
      return true;
    }
    if (searchableIncludedRegionStartPoints == null) {
      searchableIncludedRegionStartPoints = new TreeSet<ByteBufWritableComparable>(includedRegionStartPoints);
    }
    ByteBufWritableComparable skey = new ByteBufWritableComparable(ByteBuffer.wrap(startKey));
    boolean ret = searchableIncludedRegionStartPoints.contains(skey);
    if (LOG.isDebugEnabled()) {
      LOG.debug("startKey = (" +  startKey + "), endKey = (" + endKey + ")," +
               " searchableIncludedRegionStartPoints contains startKey = (" +  Boolean.toString(ret) + ")");
    }
    return ret;
  }

  //return the splits for the json table
  @Override
  public List<InputSplit> getSplits(JobContext context)
      throws IOException, InterruptedException {
    if (splitPoints == null) {
      return super.getSplits(context);
    } else {

      List<RowkeyRange> keys = DiffTableUtils.GenStartEndKeys(splitPoints);
      LOG.debug("keyrange number ="+ Integer.toString(keys.size()));
      DiffTableUtils.logKeyRanges(keys);

      /* get TabletInfo handle from table */
      TabletInfo[] tablets = jTable.getTabletInfos();
      String[] dummyLocs = new String[1];
      dummyLocs[0] = new String("dummy host");
      if (keys == null || keys.isEmpty()) {
          List<InputSplit> splits = new ArrayList<InputSplit>(1);

          TableSplit split = new TableSplit(jTable.getName(), super.cond, dummyLocs /*host locations*/, 0 /*estimated size*/);
          splits.add(split);
          LOG.debug("getSplits: split -> 0 -> " + split);
          return splits;
      }

      int i = 0;
      // The length of the range list is 1.
      byte[] startRow = new byte[0];
      byte[] stopRow = new byte[0];

      if (super.cond != null) {
          List<RowkeyRange> tableKeyRange = super.cond.getRowkeyRanges();
          startRow = tableKeyRange.get(0).getStartRow();
          stopRow = tableKeyRange.get(0).getStopRow();
      }

      List<InputSplit> splits = new ArrayList<InputSplit>(keys.size());
      for (ListIterator<RowkeyRange> iter = keys.listIterator(0); iter.hasNext(); ) {
        RowkeyRange k = iter.next();
        if ( !includeRegionInSplit(k.getStartRow(), k.getStopRow())) {
          LOG.debug("Range " + i + " : " + k.toString() +" is NOT included");
          continue;
        }
        LOG.debug("Range " + i + " : " + k.toString() +" is included");

        // determine if the given start an stop key fall into the region
        if ((startRow.length == 0 || k.getStopRow().length == 0 ||
            Bytes.compareTo(startRow, k.getStopRow()) < 0) &&
            (stopRow.length == 0 ||
             Bytes.compareTo(stopRow, k.getStartRow()) > 0)) {
          byte[] splitStart = startRow.length == 0 ||
            Bytes.compareTo(k.getStartRow(), startRow) >= 0 ?
              k.getStartRow() : startRow;
          byte[] splitStop = (stopRow.length == 0 ||
            Bytes.compareTo(k.getStopRow(), stopRow) <= 0) &&
            k.getStopRow().length > 0 ?
              k.getStopRow() : stopRow;

          QueryCondition cond = getCond(splitStart, splitStop);
          TableSplit split = new TableSplit(jTable.getName(), cond, dummyLocs /*host location*/, (long)0 /*estimated size*/);
          splits.add(split);
          LOG.debug("getSplits: split -> " + i + " -> " + split);
          ++i;
        }
      }
      return splits;
    }
  }

  //Set the condition start row.
  private QueryCondition addCondStarRow(QueryCondition cond, byte[] splitStart) {
    if (splitStart == null) {
      return cond;
    }

    Value splitStartValue = IdCodec.decode(splitStart);
    Type startType = splitStartValue.getType();
    if (startType == Type.STRING) {
      LOG.debug("decoded splitStart = (" + splitStartValue.getString()+ ")");
      cond.is(ID_FIELD, GREATER_OR_EQUAL, splitStartValue.getString());
    } else if (startType == Type.BINARY ) {
      LOG.debug("decoded splitStart = (" + Bytes.toStringBinary(splitStartValue.getBinary())+ ")");
      cond.is(ID_FIELD, GREATER_OR_EQUAL, splitStartValue.getBinary());
    } else if (startType == Type.NULL ){
      LOG.debug("decoded splitStart = (" + Bytes.toStringBinary(splitStart) + ")");
      cond.is(ID_FIELD, GREATER_OR_EQUAL, ByteBuffer.wrap(splitStart));
    } else {
      throw new IllegalArgumentException("type of split start is neither binary or string, instead it is " + startType);
    }
    return cond;
  }

  //Set the condition stop row.
  private QueryCondition addCondStopRow(QueryCondition cond, byte[] splitStop) {
    if (splitStop == null) {
      return cond;
    }

    Value splitStopValue = IdCodec.decode(splitStop);
    Type stopType = splitStopValue.getType();
    if (stopType == Type.STRING) {
      LOG.debug("decoded splitStop = (" + splitStopValue.getString()+ ")");
      cond.is(ID_FIELD, LESS, splitStopValue.getString());
    } else if (stopType == Type.BINARY) {
      LOG.debug("decoded splitStop = (" + Bytes.toStringBinary(splitStopValue.getBinary())+ ")");
      cond.is(ID_FIELD, LESS, splitStopValue.getBinary());
    } else if (stopType == Type.NULL ){
      LOG.debug("decoded splitStop = (" + Bytes.toStringBinary(splitStop) + ")");
      cond.is(ID_FIELD, GREATER_OR_EQUAL, ByteBuffer.wrap(splitStop));
    } else {
      throw new IllegalArgumentException("type of split stop is neither binary or string, instead it is " + stopType);
    }
    return cond;
  }

  //Create a simple condition with start/stop row.
  private QueryCondition getCond(byte[] splitStart, byte[] splitStop) {

    QueryCondition cond = null;
    if (splitStart.length > 0 && splitStop.length > 0) {
      cond = MapRDB.newCondition().and();
      cond = addCondStarRow(cond, splitStart);
      cond = addCondStopRow(cond, splitStop);
      cond.close().build();
    } else if (splitStop.length > 0) {
      cond = MapRDB.newCondition();
      cond = addCondStopRow(cond, splitStop);
      cond.build();
    } else if (splitStart.length > 0) {
      cond = MapRDB.newCondition();
      cond = addCondStarRow(cond, splitStart);
      cond.build();
    } else {
      cond = MapRDB.newCondition().build();
    }
    return cond;
  }
}

