/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;

public class HiveJoinToMultiJoinRule
extends RelOptRule {
    public static final HiveJoinToMultiJoinRule INSTANCE = new HiveJoinToMultiJoinRule(HiveJoin.class, HiveProject.DEFAULT_PROJECT_FACTORY);
    private final RelFactories.ProjectFactory projectFactory;
    private static final transient Log LOG = LogFactory.getLog(HiveJoinToMultiJoinRule.class);

    public HiveJoinToMultiJoinRule(Class<? extends Join> clazz, RelFactories.ProjectFactory projectFactory) {
        super(HiveJoinToMultiJoinRule.operand(clazz, (RelOptRuleOperand)HiveJoinToMultiJoinRule.operand(RelNode.class, (RelOptRuleOperandChildren)HiveJoinToMultiJoinRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[]{HiveJoinToMultiJoinRule.operand(RelNode.class, (RelOptRuleOperandChildren)HiveJoinToMultiJoinRule.any())}));
        this.projectFactory = projectFactory;
    }

    public void onMatch(RelOptRuleCall call) {
        Join newJoin;
        RelNode right;
        RelNode left;
        HiveJoin join = (HiveJoin)call.rel(0);
        RelNode multiJoin = HiveJoinToMultiJoinRule.mergeJoin(join, left = call.rel(1), right = call.rel(2));
        if (multiJoin != null) {
            call.transformTo(multiJoin);
            return;
        }
        RelNode swapped = JoinCommuteRule.swap((Join)join, (boolean)true);
        assert (swapped != null);
        Project topProject = null;
        if (swapped instanceof Join) {
            newJoin = (Join)swapped;
        } else {
            topProject = (Project)swapped;
            newJoin = (Join)swapped.getInput(0);
        }
        multiJoin = HiveJoinToMultiJoinRule.mergeJoin(newJoin, right, left);
        if (multiJoin != null) {
            if (topProject != null) {
                multiJoin = this.projectFactory.createProject(multiJoin, topProject.getChildExps(), topProject.getRowType().getFieldNames());
            }
            call.transformTo(multiJoin);
            return;
        }
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    private static RelNode mergeJoin(Join join, RelNode left, RelNode right) {
        RexNode otherCondition;
        int i;
        boolean combinable;
        List<JoinRelType> leftJoinTypes;
        List<Pair<Integer, Integer>> leftJoinInputs;
        RexNode leftCondition;
        RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        ArrayList<RelNode> newInputs = Lists.newArrayList();
        ArrayList<RexNode> newJoinFilters = Lists.newArrayList();
        newJoinFilters.add(join.getCondition());
        ArrayList<Pair> joinSpecs = Lists.newArrayList();
        if (!(left instanceof Join)) {
            if (!(left instanceof HiveMultiJoin)) return null;
        }
        if (left instanceof Join) {
            Join hj = (Join)left;
            leftCondition = hj.getCondition();
            leftJoinInputs = ImmutableList.of(Pair.of((Object)0, (Object)1));
            leftJoinTypes = ImmutableList.of(hj.getJoinType());
        } else {
            HiveMultiJoin hmj = (HiveMultiJoin)left;
            leftCondition = hmj.getCondition();
            leftJoinInputs = hmj.getJoinInputs();
            leftJoinTypes = hmj.getJoinTypes();
        }
        try {
            combinable = HiveJoinToMultiJoinRule.isCombinablePredicate(join, join.getCondition(), leftCondition);
        }
        catch (CalciteSemanticException e) {
            LOG.trace((Object)"Failed to merge joins", (Throwable)e);
            return null;
        }
        if (!combinable) return null;
        newJoinFilters.add(leftCondition);
        for (i = 0; i < leftJoinInputs.size(); ++i) {
            joinSpecs.add(Pair.of(leftJoinInputs.get(i), (Object)leftJoinTypes.get(i)));
        }
        newInputs.addAll(left.getInputs());
        int numberLeftInputs = newInputs.size();
        newInputs.add(right);
        if (newJoinFilters.size() == 1) {
            return null;
        }
        ImmutableList<RelDataTypeField> systemFieldList = ImmutableList.of();
        ArrayList<List<RexNode>> joinKeyExprs = new ArrayList<List<RexNode>>();
        ArrayList<Integer> filterNulls = new ArrayList<Integer>();
        for (i = 0; i < newInputs.size(); ++i) {
            joinKeyExprs.add(new ArrayList());
        }
        try {
            otherCondition = HiveRelOptUtil.splitHiveJoinCondition(systemFieldList, newInputs, join.getCondition(), joinKeyExprs, filterNulls, null);
        }
        catch (CalciteSemanticException e) {
            LOG.trace((Object)"Failed to merge joins", (Throwable)e);
            return null;
        }
        if (!otherCondition.isAlwaysTrue()) {
            return null;
        }
        ImmutableBitSet.Builder keysInInputsBuilder = ImmutableBitSet.builder();
        for (int i2 = 0; i2 < newInputs.size(); ++i2) {
            List partialCondition = (List)joinKeyExprs.get(i2);
            if (partialCondition.isEmpty()) continue;
            keysInInputsBuilder.set(i2);
        }
        ImmutableBitSet keysInInputs = keysInInputsBuilder.build();
        ImmutableBitSet leftReferencedInputs = keysInInputs.intersect(ImmutableBitSet.range((int)numberLeftInputs));
        ImmutableBitSet rightReferencedInputs = keysInInputs.intersect(ImmutableBitSet.range((int)numberLeftInputs, (int)newInputs.size()));
        if (join.getJoinType() != JoinRelType.INNER) {
            if (leftReferencedInputs.cardinality() > 1) return null;
            if (rightReferencedInputs.cardinality() > 1) {
                return null;
            }
        }
        if (join.getJoinType() != JoinRelType.INNER) {
            int leftInput = keysInInputs.nextSetBit(0);
            int rightInput = keysInInputs.nextSetBit(numberLeftInputs);
            joinSpecs.add(Pair.of((Object)Pair.of((Object)leftInput, (Object)rightInput), (Object)join.getJoinType()));
        } else {
            Iterator i$ = leftReferencedInputs.iterator();
            while (i$.hasNext()) {
                int i3 = (Integer)i$.next();
                Iterator i$2 = rightReferencedInputs.iterator();
                while (i$2.hasNext()) {
                    int j = (Integer)i$2.next();
                    joinSpecs.add(Pair.of((Object)Pair.of((Object)i3, (Object)j), (Object)join.getJoinType()));
                }
            }
        }
        RexNode newCondition = RexUtil.flatten((RexBuilder)rexBuilder, (RexNode)RexUtil.composeConjunction((RexBuilder)rexBuilder, newJoinFilters, (boolean)false));
        return new HiveMultiJoin(join.getCluster(), newInputs, newCondition, join.getRowType(), Pair.left(joinSpecs), Pair.right(joinSpecs));
    }

    private static boolean isCombinablePredicate(Join join, RexNode condition, RexNode otherCondition) throws CalciteSemanticException {
        HiveCalciteUtil.JoinPredicateInfo joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join, condition);
        HiveCalciteUtil.JoinPredicateInfo otherJoinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join, otherCondition);
        if (joinPredInfo.getProjsFromLeftPartOfJoinKeysInJoinSchema().equals(otherJoinPredInfo.getProjsFromLeftPartOfJoinKeysInJoinSchema())) {
            return false;
        }
        return !joinPredInfo.getProjsFromRightPartOfJoinKeysInJoinSchema().equals(otherJoinPredInfo.getProjsFromRightPartOfJoinKeysInJoinSchema());
    }
}

