package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hive.com.google.common.collect.ImmutableList;
import org.apache.hive.com.google.common.collect.Lists;

/* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.class */
public class HiveJoinToMultiJoinRule extends RelOptRule {
    public static final HiveJoinToMultiJoinRule INSTANCE;
    private final RelFactories.ProjectFactory projectFactory;
    private static final transient Log LOG;
    static final /* synthetic */ boolean $assertionsDisabled;

    public HiveJoinToMultiJoinRule(Class<? extends Join> cls, RelFactories.ProjectFactory projectFactory) {
        super(operand(cls, operand(RelNode.class, any()), new RelOptRuleOperand[]{operand(RelNode.class, any())}));
        this.projectFactory = projectFactory;
    }

    public void onMatch(RelOptRuleCall relOptRuleCall) {
        Join input;
        HiveJoin hiveJoin = (HiveJoin) relOptRuleCall.rel(0);
        RelNode rel = relOptRuleCall.rel(1);
        RelNode rel2 = relOptRuleCall.rel(2);
        RelNode mergeJoin = mergeJoin(hiveJoin, rel, rel2);
        if (mergeJoin != null) {
            relOptRuleCall.transformTo(mergeJoin);
            return;
        }
        Join swap = JoinCommuteRule.swap(hiveJoin, true);
        if (!$assertionsDisabled && swap == null) {
            throw new AssertionError();
        }
        Project project = null;
        if (swap instanceof Join) {
            input = swap;
        } else {
            project = (Project) swap;
            input = swap.getInput(0);
        }
        RelNode mergeJoin2 = mergeJoin(input, rel2, rel);
        if (mergeJoin2 != null) {
            if (project != null) {
                mergeJoin2 = this.projectFactory.createProject(mergeJoin2, project.getChildExps(), project.getRowType().getFieldNames());
            }
            relOptRuleCall.transformTo(mergeJoin2);
        }
    }

    private static RelNode mergeJoin(Join join, RelNode relNode, RelNode relNode2) {
        RexNode condition;
        List<Pair<Integer, Integer>> joinInputs;
        List<JoinRelType> joinTypes;
        boolean z;
        RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        ArrayList newArrayList = Lists.newArrayList();
        ArrayList newArrayList2 = Lists.newArrayList();
        newArrayList2.add(join.getCondition());
        ArrayList newArrayList3 = Lists.newArrayList();
        if (!(relNode instanceof Join) && !(relNode instanceof HiveMultiJoin)) {
            return null;
        }
        if (relNode instanceof Join) {
            Join join2 = (Join) relNode;
            condition = join2.getCondition();
            joinInputs = ImmutableList.of(Pair.of(0, 1));
            joinTypes = ImmutableList.of(join2.getJoinType());
        } else {
            HiveMultiJoin hiveMultiJoin = (HiveMultiJoin) relNode;
            condition = hiveMultiJoin.getCondition();
            joinInputs = hiveMultiJoin.getJoinInputs();
            joinTypes = hiveMultiJoin.getJoinTypes();
        }
        try {
            z = isCombinablePredicate(join, join.getCondition(), condition);
        } catch (CalciteSemanticException e) {
            LOG.trace("Failed to merge joins", e);
            z = false;
        }
        if (!z) {
            return null;
        }
        newArrayList2.add(condition);
        for (int i = 0; i < joinInputs.size(); i++) {
            newArrayList3.add(Pair.of(joinInputs.get(i), joinTypes.get(i)));
        }
        newArrayList.addAll(relNode.getInputs());
        int size = newArrayList.size();
        newArrayList.add(relNode2);
        if (newArrayList2.size() == 1) {
            return null;
        }
        ImmutableList of = ImmutableList.of();
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int i2 = 0; i2 < newArrayList.size(); i2++) {
            arrayList.add(new ArrayList());
        }
        try {
            if (!HiveRelOptUtil.splitHiveJoinCondition(of, newArrayList, join.getCondition(), arrayList, arrayList2, null).isAlwaysTrue()) {
                return null;
            }
            ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
            for (int i3 = 0; i3 < newArrayList.size(); i3++) {
                if (!((List) arrayList.get(i3)).isEmpty()) {
                    builder.set(i3);
                }
            }
            ImmutableBitSet build = builder.build();
            ImmutableBitSet intersect = build.intersect(ImmutableBitSet.range(size));
            ImmutableBitSet intersect2 = build.intersect(ImmutableBitSet.range(size, newArrayList.size()));
            if (join.getJoinType() != JoinRelType.INNER && (intersect.cardinality() > 1 || intersect2.cardinality() > 1)) {
                return null;
            }
            if (join.getJoinType() != JoinRelType.INNER) {
                newArrayList3.add(Pair.of(Pair.of(Integer.valueOf(build.nextSetBit(0)), Integer.valueOf(build.nextSetBit(size))), join.getJoinType()));
            } else {
                Iterator it = intersect.iterator();
                while (it.hasNext()) {
                    int intValue = ((Integer) it.next()).intValue();
                    Iterator it2 = intersect2.iterator();
                    while (it2.hasNext()) {
                        newArrayList3.add(Pair.of(Pair.of(Integer.valueOf(intValue), Integer.valueOf(((Integer) it2.next()).intValue())), join.getJoinType()));
                    }
                }
            }
            return new HiveMultiJoin(join.getCluster(), newArrayList, RexUtil.flatten(rexBuilder, RexUtil.composeConjunction(rexBuilder, newArrayList2, false)), join.getRowType(), Pair.left(newArrayList3), Pair.right(newArrayList3));
        } catch (CalciteSemanticException e2) {
            LOG.trace("Failed to merge joins", e2);
            return null;
        }
    }

    private static boolean isCombinablePredicate(Join join, RexNode rexNode, RexNode rexNode2) throws CalciteSemanticException {
        HiveCalciteUtil.JoinPredicateInfo constructJoinPredicateInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join, rexNode);
        HiveCalciteUtil.JoinPredicateInfo constructJoinPredicateInfo2 = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join, rexNode2);
        return (constructJoinPredicateInfo.getProjsFromLeftPartOfJoinKeysInJoinSchema().equals(constructJoinPredicateInfo2.getProjsFromLeftPartOfJoinKeysInJoinSchema()) || constructJoinPredicateInfo.getProjsFromRightPartOfJoinKeysInJoinSchema().equals(constructJoinPredicateInfo2.getProjsFromRightPartOfJoinKeysInJoinSchema())) ? false : true;
    }

    static {
        $assertionsDisabled = !HiveJoinToMultiJoinRule.class.desiredAssertionStatus();
        INSTANCE = new HiveJoinToMultiJoinRule(HiveJoin.class, HiveProject.DEFAULT_PROJECT_FACTORY);
        LOG = LogFactory.getLog(HiveJoinToMultiJoinRule.class);
    }
}
