/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.spark;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.GroupByOperator;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.MuxOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.UnionOperator;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.optimizer.BucketMapjoinProc;
import org.apache.hadoop.hive.ql.optimizer.MapJoinProcessor;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.spark.OptimizeSparkProcContext;
import org.apache.hadoop.hive.ql.plan.JoinDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.OpTraits;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.Statistics;
import org.apache.hadoop.hive.serde2.binarysortable.BinarySortableSerDe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Node processor for the Spark planner that tries to convert a common {@link JoinOperator}
 * into a {@link MapJoinOperator}. When {@code HiveConf.ConfVars.HIVEOPTBUCKETMAPJOIN} is
 * enabled, it also tries to turn the result into a bucketed map join.
 *
 * <p>Conversion is size-driven. One input is chosen as the "big" (streamed) table. The summed
 * data size of all remaining inputs, plus the recorded sizes of map joins connected to this
 * one, must not exceed
 * {@code HiveConf.ConfVars.HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD}.
 */
public class SparkMapJoinOptimizer
implements NodeProcessor {
    private static final Logger LOG = LoggerFactory.getLogger(SparkMapJoinOptimizer.class);

    /**
     * Attempts to convert the visited join into a map join.
     *
     * @param nd the visited node; expected to be a {@link JoinOperator}
     * @param stack the walker's node stack (not used here)
     * @param procCtx the processing context; expected to be an {@link OptimizeSparkProcContext}
     * @param nodeOutputs outputs of already-processed nodes (not used here)
     * @return the newly created {@link MapJoinOperator}, or {@code null} when join conversion is
     *     disabled or this join cannot be converted (in which case the plan is left unchanged)
     * @throws SemanticException if the underlying conversion fails
     */
    @Override
    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
        OptimizeSparkProcContext context = (OptimizeSparkProcContext)procCtx;
        HiveConf conf = context.getConf();
        JoinOperator joinOp = (JoinOperator)nd;
        if (!conf.getBoolVar(HiveConf.ConfVars.HIVECONVERTJOIN)) {
            return null;
        }
        LOG.info("Check if it can be converted to map join");
        // [0] = big table position (negative: not convertible),
        // [1] = summed size of connected map joins, [2] = summed size of this join's small tables.
        long[] mapJoinInfo = getMapJoinConversionInfo(joinOp, context);
        int mapJoinConversionPos = (int)mapJoinInfo[0];
        if (mapJoinConversionPos < 0) {
            return null;
        }
        LOG.info("Convert to non-bucketed map join");
        MapJoinOperator mapJoinOp = convertJoinMapJoin(joinOp, context, mapJoinConversionPos);
        if (mapJoinOp == null) {
            // convertJoinMapJoin declines (before modifying the plan) when a parent is a
            // MuxOperator; keep the original join instead of dereferencing null below.
            LOG.info("Can not convert to map join: join has a MuxOperator parent");
            return null;
        }
        if (conf.getBoolVar(HiveConf.ConfVars.HIVE_VECTORIZATION_MAPJOIN_NATIVE_ENABLED)
                && conf.getBoolVar(HiveConf.ConfVars.HIVE_VECTORIZATION_ENABLED)) {
            // Presumably the native vectorized map join expects BinarySortableSerDe-encoded keys.
            ((MapJoinDesc)mapJoinOp.getConf()).getKeyTblDesc().getProperties()
                    .setProperty("serialization.lib", BinarySortableSerDe.class.getName());
        }
        int numBuckets = -1;
        List<List<String>> bucketColNames = null;
        if (conf.getBoolVar(HiveConf.ConfVars.HIVEOPTBUCKETMAPJOIN)) {
            LOG.info("Check if it can be converted to bucketed map join");
            numBuckets = convertJoinBucketMapJoin(joinOp, mapJoinOp, context, mapJoinConversionPos);
            if (numBuckets > 1) {
                LOG.info("Converted to map join with {} buckets", numBuckets);
                bucketColNames = joinOp.getOpTraits().getBucketColNames();
                // Scale the recorded small-table size by the bucket count (presumably each task
                // only loads one bucket of each small table -- TODO confirm).
                mapJoinInfo[2] = mapJoinInfo[2] / (long)numBuckets;
            } else {
                LOG.info("Can not convert to bucketed map join");
            }
        }
        mapJoinOp.setOpTraits(new OpTraits(bucketColNames, numBuckets, null));
        mapJoinOp.setStatistics(joinOp.getStatistics());
        setNumberOfBucketsOnChildren(mapJoinOp);
        // Record this map join's footprint so later conversions can account for it.
        context.getMjOpSizes().put(mapJoinOp, mapJoinInfo[1] + mapJoinInfo[2]);
        return mapJoinOp;
    }

    /**
     * Recursively copies {@code currentOp}'s bucket count into its descendants' traits.
     * Propagation stops at {@link ReduceSinkOperator} and {@link GroupByOperator} children,
     * which are not updated. When the bucket count is negative, the bucket column names of
     * every updated descendant are cleared as well.
     *
     * @param currentOp operator whose traits supply the bucket count
     */
    private void setNumberOfBucketsOnChildren(Operator<? extends OperatorDesc> currentOp) {
        int numBuckets = currentOp.getOpTraits().getNumBuckets();
        for (Operator<? extends OperatorDesc> child : currentOp.getChildOperators()) {
            if (child instanceof ReduceSinkOperator || child instanceof GroupByOperator) {
                continue;
            }
            child.getOpTraits().setNumBuckets(numBuckets);
            if (numBuckets < 0) {
                child.getOpTraits().setBucketColNames(null);
            }
            setNumberOfBucketsOnChildren(child);
        }
    }

    /**
     * Tries to turn {@code mapJoinOp} into a bucketed map join via
     * {@link BucketMapjoinProc#checkAndConvertBucketMapJoin}.
     *
     * @param joinOp the original join, used for its position-to-alias mapping
     * @param mapJoinOp the already converted map join (its pos-to-alias map is overwritten)
     * @param context optimizer context supplying the {@link ParseContext}
     * @param bigTablePosition position of the big table among the join inputs
     * @return the size of the big table's bucket-number mapping when the conversion succeeded,
     *     otherwise {@code -1}
     * @throws SemanticException if the bucket map join check fails
     */
    private int convertJoinBucketMapJoin(JoinOperator joinOp, MapJoinOperator mapJoinOp, OptimizeSparkProcContext context, int bigTablePosition) throws SemanticException {
        ParseContext parseContext = context.getParseContext();
        ArrayList<String> joinAliases = new ArrayList<String>();
        String baseBigAlias = null;
        Map<Integer, Set<String>> posToAliasMap = joinOp.getPosToAliasMap();
        for (Map.Entry<Integer, Set<String>> entry : posToAliasMap.entrySet()) {
            if (entry.getKey().intValue() == bigTablePosition) {
                // The first alias at the big-table position is used as the base big alias.
                baseBigAlias = entry.getValue().iterator().next();
            }
            for (String alias : entry.getValue()) {
                if (!joinAliases.contains(alias)) {
                    joinAliases.add(alias);
                }
            }
        }
        mapJoinOp.setPosToAliasMap(posToAliasMap);
        BucketMapjoinProc.checkAndConvertBucketMapJoin(parseContext, mapJoinOp, baseBigAlias, joinAliases);
        MapJoinDesc joinDesc = (MapJoinDesc)mapJoinOp.getConf();
        return joinDesc.isBucketMapJoin() ? joinDesc.getBigTableBucketNumMapping().size() : -1;
    }

    /**
     * Result of {@link #getMapJoinConversionInfo} meaning "do not convert". A fresh array is
     * returned each time because callers may write into the result.
     */
    private static long[] notConvertible() {
        return new long[]{-1L, 0L, 0L};
    }

    /**
     * Decides whether the join can become a map join and, if so, which input is the big table.
     *
     * <p>Among the big-table candidates (from {@link MapJoinProcessor#getBigTableCandidates}),
     * the input with the largest data size is picked. Every other input is a map-side (small)
     * table, and the small tables' summed data size must not exceed the no-conditional-task
     * threshold. Conversion is refused when:
     * <ul>
     *   <li>an input lacks statistics;</li>
     *   <li>the first parent of an input contains a union that is not fed through a
     *       ReduceSink;</li>
     *   <li>a non-candidate input alone exceeds the threshold;</li>
     *   <li>an input bigger than an already chosen over-threshold big table appears;</li>
     *   <li>there is no candidate at all;</li>
     *   <li>the small tables plus connected map joins exceed the threshold.</li>
     * </ul>
     *
     * @param joinOp the join being considered
     * @param context optimizer context supplying configuration and recorded map join sizes
     * @return {@code {bigTablePosition, connectedMapJoinSize, totalSmallTableSize}}, or
     *     {@code {-1, 0, 0}} when the join cannot be converted
     */
    private long[] getMapJoinConversionInfo(JoinOperator joinOp, OptimizeSparkProcContext context) {
        Set<Integer> bigTableCandidateSet = MapJoinProcessor.getBigTableCandidates(((JoinDesc)joinOp.getConf()).getConds());
        long maxSize = context.getConf().getLongVar(HiveConf.ConfVars.HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD);
        int bigTablePosition = -1;
        Statistics bigInputStat = null;
        long totalSize = 0L;
        int pos = 0;
        // Set once a candidate that alone exceeds maxSize has been chosen as the big table.
        boolean bigTableFound = false;
        for (Operator<? extends OperatorDesc> parentOp : joinOp.getParentOperators()) {
            Statistics currInputStat = parentOp.getStatistics();
            if (currInputStat == null) {
                LOG.warn("Couldn't get statistics from: {}", parentOp);
                return notConvertible();
            }
            if (containUnionWithoutRS(parentOp.getParentOperators().get(0))) {
                return notConvertible();
            }
            long inputSize = currInputStat.getDataSize();
            boolean isCandidate = bigTableCandidateSet.contains(pos);
            boolean largerThanBig = bigInputStat == null || inputSize > bigInputStat.getDataSize();
            if (largerThanBig) {
                if (bigTableFound) {
                    // The chosen big table already exceeds maxSize and this input is bigger
                    // still; neither could be held on the map side.
                    return notConvertible();
                }
                if (inputSize > maxSize) {
                    if (!isCandidate) {
                        // Too big for the map side and not allowed to be the big table.
                        return notConvertible();
                    }
                    bigTableFound = true;
                }
            }
            if (isCandidate && largerThanBig) {
                if (bigInputStat != null) {
                    // The previous big table is demoted to a map-side table.
                    totalSize += bigInputStat.getDataSize();
                }
                bigTablePosition = pos;
                bigInputStat = currInputStat;
            } else {
                // Every input that does not become the big table is a map-side table,
                // including non-candidates that happen to be larger than the current big table.
                totalSize += inputSize;
            }
            if (totalSize > maxSize) {
                return notConvertible();
            }
            ++pos;
        }
        if (bigTablePosition == -1) {
            return notConvertible();
        }
        long connectedMapJoinSize = getConnectedMapJoinSize(joinOp.getParentOperators().get(bigTablePosition), joinOp, context);
        if (connectedMapJoinSize + totalSize > maxSize) {
            return notConvertible();
        }
        return new long[]{bigTablePosition, connectedMapJoinSize, totalSize};
    }

    /**
     * Sums the recorded sizes of map joins reachable upstream of the big-table input and
     * downstream of the join. The search in each direction stops at union and ReduceSink
     * operators. Presumably such map joins end up in the same task as this one -- TODO confirm.
     *
     * @param parentOp the join's big-table input; the upstream search starts at its parents
     * @param joinOp the join; the downstream search starts here
     * @param ctx optimizer context holding the recorded map join sizes
     * @return the summed recorded sizes
     */
    private long getConnectedMapJoinSize(Operator<? extends OperatorDesc> parentOp, Operator<? extends OperatorDesc> joinOp, OptimizeSparkProcContext ctx) {
        long upstream = 0L;
        for (Operator<? extends OperatorDesc> grandParentOp : parentOp.getParentOperators()) {
            upstream += getConnectedParentMapJoinSize(grandParentOp, ctx);
        }
        long downstream = getConnectedChildMapJoinSize(joinOp, ctx);
        return upstream + downstream;
    }

    /**
     * Walks upstream from {@code op} and sums the recorded size of the first map join found on
     * each path. Returns 0 for paths that reach a union or ReduceSink operator first.
     * NOTE(review): unboxing throws NullPointerException if a MapJoinOperator has no recorded size.
     */
    private long getConnectedParentMapJoinSize(Operator<? extends OperatorDesc> op, OptimizeSparkProcContext ctx) {
        if (op instanceof UnionOperator || op instanceof ReduceSinkOperator) {
            return 0L;
        }
        if (op instanceof MapJoinOperator) {
            // The recorded size already includes map joins connected to that one.
            long mjSize = ctx.getMjOpSizes().get(op);
            return mjSize;
        }
        long result = 0L;
        for (Operator<? extends OperatorDesc> parentOp : op.getParentOperators()) {
            result += getConnectedParentMapJoinSize(parentOp, ctx);
        }
        return result;
    }

    /**
     * Walks downstream from {@code op} and sums the recorded size of the first map join found on
     * each path. Returns 0 for paths that reach a union or ReduceSink operator first.
     * NOTE(review): unboxing throws NullPointerException if a MapJoinOperator has no recorded size.
     */
    private long getConnectedChildMapJoinSize(Operator<? extends OperatorDesc> op, OptimizeSparkProcContext ctx) {
        if (op instanceof UnionOperator || op instanceof ReduceSinkOperator) {
            return 0L;
        }
        if (op instanceof MapJoinOperator) {
            long mjSize = ctx.getMjOpSizes().get(op);
            return mjSize;
        }
        long result = 0L;
        for (Operator<? extends OperatorDesc> childOp : op.getChildOperators()) {
            result += getConnectedChildMapJoinSize(childOp, ctx);
        }
        return result;
    }

    /**
     * Replaces {@code joinOp} with a map join whose big table is at {@code bigTablePosition}.
     * If the big table's input is a ReduceSink, that ReduceSink is bypassed: its own parent is
     * wired directly to the new map join.
     *
     * @param joinOp the join to convert
     * @param context optimizer context supplying configuration
     * @param bigTablePosition position of the big table among the join inputs
     * @return the new map join, or {@code null} (with the plan untouched) when any parent of
     *     {@code joinOp} is a {@link MuxOperator}
     * @throws SemanticException if {@link MapJoinProcessor#convertJoinOpMapJoinOp} fails
     */
    public MapJoinOperator convertJoinMapJoin(JoinOperator joinOp, OptimizeSparkProcContext context, int bigTablePosition) throws SemanticException {
        for (Operator<? extends OperatorDesc> parentOp : joinOp.getParentOperators()) {
            if (parentOp instanceof MuxOperator) {
                return null;
            }
        }
        JoinDesc joinDesc = (JoinDesc)joinOp.getConf();
        MapJoinOperator mapJoinOp = MapJoinProcessor.convertJoinOpMapJoinOp(context.getConf(), joinOp, joinDesc.isLeftInputJoin(), joinDesc.getBaseSrc(), joinDesc.getMapAliases(), bigTablePosition, true);
        Operator<? extends OperatorDesc> parentBigTableOp = mapJoinOp.getParentOperators().get(bigTablePosition);
        if (parentBigTableOp instanceof ReduceSinkOperator) {
            // Read before removeChild below, which edits the ReduceSink's parent list.
            Operator<? extends OperatorDesc> rsParent = parentBigTableOp.getParentOperators().get(0);
            mapJoinOp.getParentOperators().remove(bigTablePosition);
            if (!mapJoinOp.getParentOperators().contains(rsParent)) {
                mapJoinOp.getParentOperators().add(bigTablePosition, rsParent);
            }
            rsParent.removeChild(parentBigTableOp);
            // Make every map join parent point at the map join instead of the old join.
            for (Operator<? extends OperatorDesc> op : mapJoinOp.getParentOperators()) {
                if (!op.getChildOperators().contains(mapJoinOp)) {
                    op.getChildOperators().add(mapJoinOp);
                }
                op.getChildOperators().remove(joinOp);
            }
        }
        ((MapJoinDesc)mapJoinOp.getConf()).setQBJoinTreeProps(joinDesc);
        return mapJoinOp;
    }

    /**
     * Returns whether a {@link UnionOperator} with at least one non-ReduceSink parent is
     * reachable upstream of {@code op}. The search does not continue past ReduceSink operators.
     */
    private boolean containUnionWithoutRS(Operator<? extends OperatorDesc> op) {
        if (op instanceof ReduceSinkOperator) {
            return false;
        }
        if (op instanceof UnionOperator) {
            for (Operator<? extends OperatorDesc> pop : op.getParentOperators()) {
                if (!(pop instanceof ReduceSinkOperator)) {
                    return true;
                }
            }
            return false;
        }
        for (Operator<? extends OperatorDesc> pop : op.getParentOperators()) {
            if (containUnionWithoutRS(pop)) {
                return true;
            }
        }
        return false;
    }
}

