package org.apache.hive.druid.org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Predicate;
import org.apache.hive.druid.org.apache.calcite.plan.RelOptRule;
import org.apache.hive.druid.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.hive.druid.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.hive.druid.org.apache.calcite.plan.RelOptUtil;
import org.apache.hive.druid.org.apache.calcite.rel.RelNode;
import org.apache.hive.druid.org.apache.calcite.rel.core.Aggregate;
import org.apache.hive.druid.org.apache.calcite.rel.core.Join;
import org.apache.hive.druid.org.apache.calcite.rel.core.JoinInfo;
import org.apache.hive.druid.org.apache.calcite.rel.core.Project;
import org.apache.hive.druid.org.apache.calcite.rel.core.RelFactories;
import org.apache.hive.druid.org.apache.calcite.rex.RexBuilder;
import org.apache.hive.druid.org.apache.calcite.tools.RelBuilder;
import org.apache.hive.druid.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.hive.druid.org.apache.calcite.util.ImmutableBitSet;
import org.apache.hive.druid.org.apache.calcite.util.ImmutableIntList;

/* loaded from: input_file:org/apache/hive/druid/org/apache/calcite/rel/rules/SemiJoinRule.class */
/**
 * Planner rule that turns a {@link Join} whose right input is an {@link Aggregate} grouped on
 * exactly the join keys into a semi-join against the Aggregate's input, provided that no column
 * of the right input is needed above the join.
 *
 * <p>{@link #PROJECT} matches a {@link Project} on top of the join (the Project must reference
 * only left-hand columns); {@link #JOIN} matches the join on its own.
 */
public abstract class SemiJoinRule extends RelOptRule {
    /** Accepts only joins whose type never null-extends the left input. */
    private static final Predicate<Join> NOT_GENERATE_NULLS_ON_LEFT =
            join -> !join.getJoinType().generatesNullsOnLeft();

    /** Tests whether an Aggregate produces no output columns at all. */
    private static final Predicate<Aggregate> IS_EMPTY_AGGREGATE =
            aggregate -> aggregate.getRowType().getFieldCount() == 0;

    // These instances must be declared after the predicates above: their constructors read
    // NOT_GENERATE_NULLS_ON_LEFT during static initialization.
    public static final SemiJoinRule PROJECT = new ProjectToSemiJoinRule(Project.class, Join.class,
            Aggregate.class, RelFactories.LOGICAL_BUILDER, "SemiJoinRule:project");
    public static final SemiJoinRule JOIN = new JoinToSemiJoinRule(Join.class, Aggregate.class,
            RelFactories.LOGICAL_BUILDER, "SemiJoinRule:join");

    /**
     * Variant matching a {@link Join} (left input: any {@link RelNode}; right input: an
     * {@link Aggregate}) with no {@link Project} above it. Uses the inherited
     * {@link SemiJoinRule#onMatch(RelOptRuleCall)}.
     */
    public static class JoinToSemiJoinRule extends SemiJoinRule {
        /** Creates a JoinToSemiJoinRule. */
        public JoinToSemiJoinRule(Class<Join> joinClass, Class<Aggregate> aggregateClass,
                RelBuilderFactory relBuilderFactory, String description) {
            super(joinClass, aggregateClass, relBuilderFactory, description);
        }
    }

    /**
     * Variant matching a {@link Project} on top of a {@link Join} whose right input is an
     * {@link Aggregate}.
     */
    public static class ProjectToSemiJoinRule extends SemiJoinRule {
        /** Creates a ProjectToSemiJoinRule. */
        public ProjectToSemiJoinRule(Class<Project> projectClass, Class<Join> joinClass,
                Class<Aggregate> aggregateClass, RelBuilderFactory relBuilderFactory,
                String description) {
            super(projectClass, joinClass, aggregateClass, relBuilderFactory, description);
        }

        /** Operands are, in order: Project, Join, the join's left input, Aggregate. */
        @Override
        public void onMatch(RelOptRuleCall call) {
            perform(call, (Project) call.rel(0), (Join) call.rel(1), call.rel(2),
                    (Aggregate) call.rel(3));
        }
    }

    /** Creates a rule matching {@code Project(Join(any, Aggregate))}. */
    protected SemiJoinRule(Class<Project> projectClass, Class<Join> joinClass,
            Class<Aggregate> aggregateClass, RelBuilderFactory relBuilderFactory,
            String description) {
        super(
                operand(projectClass,
                        some(operandJ(joinClass, null, NOT_GENERATE_NULLS_ON_LEFT,
                                some(operand(RelNode.class, any()),
                                        operand(aggregateClass, any()))))),
                relBuilderFactory, description);
    }

    /** Creates a rule matching {@code Join(any, Aggregate)}. */
    protected SemiJoinRule(Class<Join> joinClass, Class<Aggregate> aggregateClass,
            RelBuilderFactory relBuilderFactory, String description) {
        super(
                operandJ(joinClass, null, NOT_GENERATE_NULLS_ON_LEFT,
                        some(operand(RelNode.class, any()), operand(aggregateClass, any()))),
                relBuilderFactory, description);
    }

    /**
     * Attempts the rewrite and, on success, registers it via {@link RelOptRuleCall#transformTo}.
     *
     * @param call      the rule call
     * @param project   the Project above the join, or {@code null} for the join-only variant
     * @param join      the matched join
     * @param left      the join's left input
     * @param aggregate the join's right input
     */
    protected void perform(RelOptRuleCall call, Project project, Join join, RelNode left,
            Aggregate aggregate) {
        if (project != null) {
            // The right-hand columns disappear in the rewrite, so the Project may only use
            // columns that come from the left input.
            final ImmutableBitSet usedBits =
                    RelOptUtil.InputFinder.bits(project.getProjects(), null);
            final ImmutableBitSet rightBits = ImmutableBitSet.range(
                    left.getRowType().getFieldCount(), join.getRowType().getFieldCount());
            if (usedBits.intersects(rightBits)) {
                return;
            }
        } else if (join.getJoinType().projectsRight() && !IS_EMPTY_AGGREGATE.test(aggregate)) {
            // Without a Project, the join's output would include right-hand columns.
            return;
        }
        // The "at most one matching Aggregate row per left row" argument below only holds for
        // a plain GROUP BY; with grouping sets, rolled-up rows carry NULL keys and the
        // Aggregate's rows no longer correspond to distinct key tuples of its input.
        if (aggregate.getGroupType() != Aggregate.Group.SIMPLE) {
            return;
        }
        final JoinInfo joinInfo = join.analyzeCondition();
        // The join keys on the right must be exactly the Aggregate's group keys (neither a
        // superset nor a subset), and the condition must be a pure equi-join.
        if (!joinInfo.rightSet().equals(ImmutableBitSet.range(aggregate.getGroupCount()))
                || !joinInfo.isEqui()) {
            return;
        }
        final RelBuilder relBuilder = call.builder();
        relBuilder.push(left);
        switch (join.getJoinType()) {
            case SEMI:
            case INNER:
                if (aggregate.getGroupCount() == 0) {
                    // No GROUP BY keys: the Aggregate emits exactly one row even for an empty
                    // input, and with no join keys the equi-join condition is always true, so
                    // every left row survives exactly once. Semi-joining against the
                    // Aggregate's input instead would drop all rows when that input is empty.
                    break;
                }
                relBuilder.push(aggregate.getInput());
                relBuilder.semiJoin(RelOptUtil.createEquiJoinCondition(
                        relBuilder.peek(2, 0), joinInfo.leftKeys,
                        relBuilder.peek(2, 1), toAggregateInputKeys(joinInfo.rightKeys, aggregate),
                        join.getCluster().getRexBuilder()));
                break;
            case LEFT:
                // The Aggregate is unique on the join keys, so each left row matches at most one
                // right row, and LEFT keeps unmatched rows: the result is exactly the left input.
                break;
            default:
                throw new AssertionError(join.getJoinType());
        }
        if (project != null) {
            relBuilder.project(project.getProjects(), project.getRowType().getFieldNames());
        }
        call.transformTo(relBuilder.build());
    }

    /**
     * Maps join right-key ordinals (positions among the Aggregate's output group columns) to
     * the corresponding column ordinals of the Aggregate's input.
     */
    private static ImmutableIntList toAggregateInputKeys(List<Integer> rightKeys,
            Aggregate aggregate) {
        final List<Integer> groupKeys = aggregate.getGroupSet().asList();
        final List<Integer> inputKeys = new ArrayList<>(rightKeys.size());
        for (int key : rightKeys) {
            inputKeys.add(groupKeys.get(key));
        }
        return ImmutableIntList.copyOf(inputKeys);
    }

    /** Operands are, in order: Join, the join's left input, Aggregate. */
    @Override
    public void onMatch(RelOptRuleCall call) {
        perform(call, null, (Join) call.rel(0), call.rel(1), (Aggregate) call.rel(2));
    }
}
