/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.LogicVisitor;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlQuantifyOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

/**
 * Planner rules that remove {@link RexSubQuery} expressions from
 * {@link Project}, {@link Filter} and {@link Join} nodes.
 *
 * <p>Each sub-query is rewritten with a {@link RelBuilder}: its relational
 * expression (possibly aggregated or projected) is joined to the outer
 * input, and the sub-query expression itself is replaced by a field
 * reference, a literal or a CASE expression over the joined columns.
 * Supported sub-query kinds are {@code SCALAR_QUERY}, {@code SOME},
 * {@code IN} and {@code EXISTS}.
 */
public abstract class SubQueryRemoveRule extends RelOptRule {
    /** Rule that removes sub-queries from the expressions of a {@link Project}. */
    public static final SubQueryRemoveRule PROJECT =
        new SubQueryProjectRemoveRule(RelFactories.LOGICAL_BUILDER);

    /** Rule that removes sub-queries from the condition of a {@link Filter}. */
    public static final SubQueryRemoveRule FILTER =
        new SubQueryFilterRemoveRule(RelFactories.LOGICAL_BUILDER);

    /** Rule that removes sub-queries from the condition of a {@link Join}. */
    public static final SubQueryRemoveRule JOIN =
        new SubQueryJoinRemoveRule(RelFactories.LOGICAL_BUILDER);

    /**
     * Creates a SubQueryRemoveRule.
     *
     * @param operand           operand tree the rule matches
     * @param relBuilderFactory factory for the builder used to construct the rewrite
     * @param description       description of the rule
     */
    public SubQueryRemoveRule(RelOptRuleOperand operand,
            RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    /**
     * Rewrites a single sub-query.
     *
     * <p>On return, the builder's stack holds the outer input(s) joined with
     * the relational expression that replaces {@code e}. The returned
     * expression should take {@code e}'s place in the original expression tree.
     *
     * @param e            sub-query to rewrite
     * @param variablesSet correlation variables used by {@code e.rel}; empty
     *                     when the sub-query is not correlated
     * @param logic        logic the enclosing expression requires of {@code e}
     * @param builder      builder whose top {@code inputCount} entries are the
     *                     outer input(s)
     * @param inputCount   number of outer inputs on the builder's stack
     * @param offset       number of outer fields preceding the sub-query's
     *                     fields; used to locate (scalar) or shift (IN)
     *                     references to them
     * @return replacement expression for {@code e}
     * @throws AssertionError if {@code e} is not a SCALAR_QUERY, SOME, IN or
     *                        EXISTS sub-query
     */
    protected RexNode apply(RexSubQuery e, Set<CorrelationId> variablesSet,
            RelOptUtil.Logic logic, RelBuilder builder, int inputCount, int offset) {
        switch (e.getKind()) {
            case SCALAR_QUERY:
                return rewriteScalarQuery(e, variablesSet, builder, inputCount, offset);
            case SOME:
                return rewriteSome(e, variablesSet, builder);
            case IN:
                return rewriteIn(e, variablesSet, logic, builder, offset);
            case EXISTS:
                return rewriteExists(e, variablesSet, logic, builder);
            default:
                throw new AssertionError(e.getKind());
        }
    }

    /**
     * Rewrites a scalar sub-query as a LEFT join (condition TRUE) of the outer
     * input to the sub-query, and returns a reference to the sub-query's
     * result column.
     *
     * <p>Metadata may report that the sub-query is unique on the empty column
     * set, meaning it yields at most one row. If it does not, the sub-query is
     * first reduced with a {@code SINGLE_VALUE} aggregate over its first column.
     */
    private RexNode rewriteScalarQuery(RexSubQuery e, Set<CorrelationId> variablesSet,
            RelBuilder builder, int inputCount, int offset) {
        builder.push(e.rel);
        final RelMetadataQuery mq = e.rel.getCluster().getMetadataQuery();
        final Boolean unique = mq.areColumnsUnique(builder.peek(), ImmutableBitSet.of());
        if (unique == null || !unique) {
            builder.aggregate(builder.groupKey(),
                builder.aggregateCall(SqlStdOperatorTable.SINGLE_VALUE, builder.field(0)));
        }
        builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet);
        return field(builder, inputCount, offset);
    }

    /**
     * Rewrites {@code x <op> SOME (sub-query)} for the ordering comparisons
     * ({@code <, <=, >, >=}).
     *
     * <p>The sub-query is aggregated into a row {@code q} with columns:
     * <ul>
     *   <li>{@code m}: MIN of its first column for {@code >}/{@code >=}, or MAX for {@code <}/{@code <=};</li>
     *   <li>{@code c}: {@code COUNT(*)};</li>
     *   <li>{@code d}: COUNT of its first column.</li>
     * </ul>
     * The result is a CASE over those columns.
     */
    private RexNode rewriteSome(RexSubQuery e, Set<CorrelationId> variablesSet, RelBuilder builder) {
        final SqlQuantifyOperator op = (SqlQuantifyOperator) e.op;
        // Only the ordering comparisons are rewritten here.
        assert op == SqlStdOperatorTable.SOME_GE || op == SqlStdOperatorTable.SOME_LE
            || op == SqlStdOperatorTable.SOME_LT || op == SqlStdOperatorTable.SOME_GT;

        final RexNode literalFalse = builder.literal(false);
        final RexNode literalTrue = builder.literal(true);
        final RexLiteral literalUnknown =
            builder.getRexBuilder().makeNullLiteral(literalFalse.getType());
        final SqlAggFunction minMax =
            op.comparisonKind == SqlKind.GREATER_THAN
                || op.comparisonKind == SqlKind.GREATER_THAN_OR_EQUAL
            ? SqlStdOperatorTable.MIN
            : SqlStdOperatorTable.MAX;

        final RexNode caseRexNode;
        if (variablesSet.isEmpty()) {
            // Uncorrelated: INNER join (no join keys) with
            //   q = SELECT minMax(f0) AS m, COUNT(*) AS c, COUNT(f0) AS d FROM sub-query
            builder.push(e.rel)
                .aggregate(builder.groupKey(),
                    builder.aggregateCall(minMax, builder.field(0)).as("m"),
                    builder.count(false, "c"),
                    builder.count(false, "d", builder.field(0)))
                .as("q")
                .join(JoinRelType.INNER);
            // CASE WHEN q.c = 0 THEN FALSE
            //      WHEN (x <op> q.m) IS TRUE THEN TRUE
            //      WHEN q.c > q.d THEN UNKNOWN     -- some first-column value is NULL
            //      ELSE x <op> q.m END
            caseRexNode = builder.call(SqlStdOperatorTable.CASE,
                builder.call(SqlStdOperatorTable.EQUALS,
                    builder.field("q", "c"), builder.literal(0)),
                literalFalse,
                builder.call(SqlStdOperatorTable.IS_TRUE, someComparison(builder, op, e)),
                literalTrue,
                builder.call(SqlStdOperatorTable.GREATER_THAN,
                    builder.field("q", "c"), builder.field("q", "d")),
                literalUnknown,
                someComparison(builder, op, e));
        } else {
            // Correlated: the same aggregate plus a TRUE indicator column, LEFT-joined
            // so that a NULL indicator marks outer rows with no joined aggregate row.
            builder.push(e.rel)
                .aggregate(builder.groupKey(),
                    builder.aggregateCall(minMax, builder.field(0)).as("m"),
                    builder.count(false, "c"),
                    builder.count(false, "d", builder.field(0)));
            final List<RexNode> parentQueryFields = new ArrayList<>(builder.fields());
            final String indicator = "trueLiteral";
            parentQueryFields.add(builder.alias(literalTrue, indicator));
            builder.project(parentQueryFields).as("q");
            builder.join(JoinRelType.LEFT, literalTrue, variablesSet);
            caseRexNode = builder.call(SqlStdOperatorTable.CASE,
                builder.call(SqlStdOperatorTable.IS_NULL, builder.field("q", indicator)),
                literalFalse,
                builder.call(SqlStdOperatorTable.EQUALS,
                    builder.field("q", "c"), builder.literal(0)),
                literalFalse,
                builder.call(SqlStdOperatorTable.IS_TRUE, someComparison(builder, op, e)),
                literalTrue,
                builder.call(SqlStdOperatorTable.GREATER_THAN,
                    builder.field("q", "c"), builder.field("q", "d")),
                literalUnknown,
                someComparison(builder, op, e));
        }

        // The CASE above is built as a nullable boolean. When the sub-query's own
        // type is NOT NULL, cast back so that the rewrite keeps that type.
        if (!e.getType().isNullable()) {
            return builder.cast(caseRexNode, e.getType().getSqlTypeName());
        }
        return caseRexNode;
    }

    /**
     * Builds {@code <first operand of e> <comparison of op> q.m}.
     * A new expression is created on every call.
     */
    private static RexNode someComparison(RelBuilder builder, SqlQuantifyOperator op, RexSubQuery e) {
        return builder.call(RelOptUtil.op(op.comparisonKind, null),
            e.operands.get(0), builder.field("q", "m"));
    }

    /**
     * Rewrites {@code EXISTS (sub-query)}.
     *
     * <p>The sub-query's columns are replaced by a single TRUE column {@code i}.
     * For {@link RelOptUtil.Logic#TRUE}, the outer input is INNER-joined to the
     * sub-query grouped on {@code i}, and TRUE is returned. Otherwise the
     * distinct sub-query is LEFT-joined, and the result is an IS NOT NULL test
     * on the joined {@code i} column.
     */
    private RexNode rewriteExists(RexSubQuery e, Set<CorrelationId> variablesSet,
            RelOptUtil.Logic logic, RelBuilder builder) {
        builder.push(e.rel);
        builder.project(builder.alias(builder.literal(true), "i"));
        switch (logic) {
            case TRUE:
                builder.aggregate(builder.groupKey(0));
                builder.as("dt");
                builder.join(JoinRelType.INNER, builder.literal(true), variablesSet);
                return builder.literal(true);
            default:
                builder.distinct();
                break;
        }
        builder.as("dt");
        builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet);
        return builder.isNotNull(Util.last(builder.fields()));
    }

    /**
     * Rewrites {@code (k1, ..., kn) IN (sub-query)}.
     *
     * <p>If every left-hand operand is a literal, the comparisons are applied
     * inside the sub-query, which is reduced to a {@code cs} indicator column.
     * Otherwise the sub-query is joined on equality of the keys: it is reduced
     * to distinct keys for TRUE logic, and to distinct keys plus a TRUE column
     * {@code i} otherwise. For three-valued logic, a {@code ct} row of counts
     * is joined first. The result is a CASE over the joined columns, or TRUE
     * for {@link RelOptUtil.Logic#TRUE}.
     *
     * @param offset shift applied to the sub-query's field references in the
     *               join condition (increased by 2 when {@code ct} is joined first)
     */
    private RexNode rewriteIn(RexSubQuery e, Set<CorrelationId> variablesSet,
            RelOptUtil.Logic logic, RelBuilder builder, int offset) {
        builder.push(e.rel);
        final List<RexNode> fields = new ArrayList<>(builder.fields());
        final boolean allLiterals = RexUtil.allLiterals(e.getOperands());
        final List<RexNode> expressionOperands = new ArrayList<>(e.getOperands());

        // "k IS NULL" for every nullable left-hand operand.
        final List<RexNode> keyIsNulls = e.getOperands().stream()
            .filter(operand -> operand.getType().isNullable())
            .map(builder::isNull)
            .collect(Collectors.toList());

        final RexLiteral trueLiteral = (RexLiteral) builder.literal(true);
        final RexLiteral falseLiteral = (RexLiteral) builder.literal(false);
        final RexLiteral unknownLiteral =
            builder.getRexBuilder().makeNullLiteral(trueLiteral.getType());

        if (allLiterals) {
            final List<RexNode> conditions = Pair.zip(expressionOperands, fields).stream()
                .map(pair -> builder.equals(pair.left, pair.right))
                .collect(Collectors.toList());
            switch (logic) {
                case TRUE:
                case TRUE_FALSE: {
                    // Keep matching rows only; cs = TRUE.
                    builder.filter(conditions);
                    builder.project(builder.alias(trueLiteral, "cs"));
                    builder.distinct();
                    break;
                }
                default: {
                    // Keep matching rows, plus rows where a sub-query value or a key is
                    // NULL; cs = all sub-query values are non-NULL.
                    final List<RexNode> isNullOperands = fields.stream()
                        .map(builder::isNull)
                        .collect(Collectors.toList());
                    isNullOperands.addAll(keyIsNulls);
                    builder.filter(builder.or(builder.and(conditions), builder.or(isNullOperands)));
                    final RexNode project = builder.and(fields.stream()
                        .map(builder::isNotNull)
                        .collect(Collectors.toList()));
                    builder.project(builder.alias(project, "cs"));
                    if (variablesSet.isEmpty()) {
                        // Group by cs, count the rows as c, and keep one row with cs sorted descending.
                        builder.aggregate(builder.groupKey(builder.field("cs")),
                            builder.count(false, "c"));
                        builder.sortLimit(0, 1, ImmutableList.of(
                            builder.call(SqlStdOperatorTable.DESC, builder.field("cs"))));
                    } else {
                        builder.distinct();
                    }
                    break;
                }
            }
            // The comparisons were applied above; the join below gets no key conditions.
            expressionOperands.clear();
            fields.clear();
        } else {
            switch (logic) {
                case TRUE: {
                    builder.aggregate(builder.groupKey(fields));
                    break;
                }
                case TRUE_FALSE_UNKNOWN:
                case UNKNOWN_AS_TRUE: {
                    // ct = SELECT COUNT(*) AS c, COUNT(<all sub-query fields>) AS ck
                    // joined to the outer input; its two columns precede dt.
                    builder.aggregate(builder.groupKey(),
                        builder.count(false, "c"),
                        builder.count(builder.fields()).as("ck"));
                    builder.as("ct");
                    if (!variablesSet.isEmpty()) {
                        builder.join(JoinRelType.LEFT, trueLiteral, variablesSet);
                    } else {
                        builder.join(JoinRelType.INNER, trueLiteral, variablesSet);
                    }
                    offset += 2;
                    builder.push(e.rel);
                    // fall through
                }
                default: {
                    // dt = SELECT DISTINCT <sub-query fields>, TRUE AS i
                    fields.add(builder.alias(trueLiteral, "i"));
                    builder.project(fields);
                    builder.distinct();
                    break;
                }
            }
        }

        builder.as("dt");
        final int refOffset = offset;
        // k_i = dt.f_i, with the dt reference shifted past the preceding fields.
        final List<RexNode> conditions = Pair.zip(expressionOperands, builder.fields()).stream()
            .map(pair -> builder.equals(pair.left, RexUtil.shift(pair.right, refOffset)))
            .collect(Collectors.toList());
        switch (logic) {
            case TRUE:
                builder.join(JoinRelType.INNER, builder.and(conditions), variablesSet);
                return trueLiteral;
            default:
                break;
        }
        builder.join(JoinRelType.LEFT, builder.and(conditions), variablesSet);

        // CASE <WHEN/THEN pairs below> ELSE FALSE END
        final ImmutableList.Builder<RexNode> operands = ImmutableList.builder();
        RexLiteral b = trueLiteral;
        switch (logic) {
            case TRUE_FALSE_UNKNOWN:
                b = unknownLiteral;
                // fall through
            case UNKNOWN_AS_TRUE:
                if (allLiterals) {
                    if (variablesSet.isEmpty()) {
                        // No aggregated row was joined: nothing passed the filter.
                        operands.add(builder.isNull(builder.field("c")), falseLiteral);
                    }
                    operands.add(builder.equals(builder.field("cs"), falseLiteral), b);
                } else {
                    operands.add(builder.equals(builder.field("ct", "c"), builder.literal(0)),
                        falseLiteral);
                }
                break;
            default:
                break;
        }
        if (!keyIsNulls.isEmpty()) {
            operands.add(builder.or(keyIsNulls), unknownLiteral);
        }
        if (allLiterals) {
            operands.add(builder.isNotNull(builder.field("cs")), trueLiteral);
        } else {
            operands.add(builder.isNotNull(Util.last(builder.fields())), trueLiteral);
        }
        if (!allLiterals) {
            switch (logic) {
                case TRUE_FALSE_UNKNOWN:
                case UNKNOWN_AS_TRUE:
                    // ck < c: COUNT over the sub-query's fields counted fewer rows than
                    // COUNT(*), i.e. some sub-query row has a NULL field.
                    operands.add(builder.call(SqlStdOperatorTable.LESS_THAN,
                        builder.field("ct", "ck"), builder.field("ct", "c")), b);
                    break;
                default:
                    break;
            }
        }
        operands.add(falseLiteral);
        return builder.call(SqlStdOperatorTable.CASE, operands.build());
    }

    /**
     * Returns a reference to the field at position {@code offset}, counted
     * across the top {@code inputCount} inputs on the builder's stack.
     */
    private RexInputRef field(RelBuilder builder, int inputCount, int offset) {
        int inputOrdinal = 0;
        RelNode r;
        // Skip whole inputs until offset falls inside one of them.
        while (offset >= (r = builder.peek(inputCount, inputOrdinal)).getRowType().getFieldCount()) {
            ++inputOrdinal;
            offset -= r.getRowType().getFieldCount();
        }
        return builder.field(inputCount, inputOrdinal, offset);
    }

    /**
     * Returns references to the first {@code fieldCount} fields of the top
     * input. Projecting them drops the extra columns added by the rewrite.
     */
    private static List<RexNode> fields(RelBuilder builder, int fieldCount) {
        final List<RexNode> projects = new ArrayList<>();
        for (int i = 0; i < fieldCount; ++i) {
            projects.add(builder.field(i));
        }
        return projects;
    }

    /**
     * Shuttle that replaces sub-queries equal to a given sub-query with a
     * replacement expression. Other sub-queries are returned unchanged.
     */
    private static class ReplaceSubQueryShuttle extends RexShuttle {
        private final RexSubQuery subQuery;
        private final RexNode replacement;

        ReplaceSubQueryShuttle(RexSubQuery subQuery, RexNode replacement) {
            this.subQuery = subQuery;
            this.replacement = replacement;
        }

        @Override
        public RexNode visitSubQuery(RexSubQuery subQuery) {
            return subQuery.equals(this.subQuery) ? this.replacement : subQuery;
        }
    }

    /** Rule that replaces one sub-query in a {@link Join} condition per firing. */
    public static class SubQueryJoinRemoveRule extends SubQueryRemoveRule {
        public SubQueryJoinRemoveRule(RelBuilderFactory relBuilderFactory) {
            super(operandJ(Join.class, null, RexUtil.SubQueryFinder::containsSubQuery, any()),
                relBuilderFactory, "SubQueryRemoveRule:Join");
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            final Join join = (Join) call.rel(0);
            final RelBuilder builder = call.builder();
            final RexSubQuery e = RexUtil.SubQueryFinder.find(join.getCondition());
            assert e != null;
            final RelOptUtil.Logic logic = LogicVisitor.find(RelOptUtil.Logic.TRUE,
                ImmutableList.of(join.getCondition()), e);
            builder.push(join.getLeft());
            builder.push(join.getRight());
            final int fieldCount = join.getRowType().getFieldCount();
            final Set<CorrelationId> variablesSet = RelOptUtil.getVariablesUsed(e.rel);
            // NOTE(review): offset is the full join width, but the IN/EXISTS rewrites
            // join only the top input (the right side) with the sub-query. Confirm that
            // IN references in join conditions resolve as intended.
            final RexNode target = apply(e, variablesSet, logic, builder, 2, fieldCount);
            final ReplaceSubQueryShuttle shuttle = new ReplaceSubQueryShuttle(e, target);
            builder.join(join.getJoinType(), shuttle.apply(join.getCondition()));
            builder.project(SubQueryRemoveRule.fields(builder, fieldCount));
            call.transformTo(builder.build());
        }
    }

    /** Rule that replaces every sub-query in a {@link Filter} condition. */
    public static class SubQueryFilterRemoveRule extends SubQueryRemoveRule {
        public SubQueryFilterRemoveRule(RelBuilderFactory relBuilderFactory) {
            super(operandJ(Filter.class, null, RexUtil.SubQueryFinder::containsSubQuery, any()),
                relBuilderFactory, "SubQueryRemoveRule:Filter");
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            final Filter filter = (Filter) call.rel(0);
            final RelBuilder builder = call.builder();
            builder.push(filter.getInput());
            int count = 0;
            RexNode c = filter.getCondition();
            while (true) {
                final RexSubQuery e = RexUtil.SubQueryFinder.find(c);
                if (e == null) {
                    // The operand predicate guarantees at least one sub-query.
                    assert count > 0;
                    break;
                }
                ++count;
                final RelOptUtil.Logic logic =
                    LogicVisitor.find(RelOptUtil.Logic.TRUE, ImmutableList.of(c), e);
                final Set<CorrelationId> variablesSet = RelOptUtil.getVariablesUsed(e.rel);
                // Each rewrite widens the top input, so re-read its field count.
                final RexNode target = apply(e, variablesSet, logic, builder, 1,
                    builder.peek().getRowType().getFieldCount());
                final ReplaceSubQueryShuttle shuttle = new ReplaceSubQueryShuttle(e, target);
                c = c.accept(shuttle);
            }
            builder.filter(c);
            builder.project(SubQueryRemoveRule.fields(builder, filter.getRowType().getFieldCount()));
            call.transformTo(builder.build());
        }
    }

    /** Rule that replaces one sub-query in a {@link Project}'s expressions per firing. */
    public static class SubQueryProjectRemoveRule extends SubQueryRemoveRule {
        public SubQueryProjectRemoveRule(RelBuilderFactory relBuilderFactory) {
            super(operandJ(Project.class, null, RexUtil.SubQueryFinder::containsSubQuery, any()),
                relBuilderFactory, "SubQueryRemoveRule:Project");
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            final Project project = (Project) call.rel(0);
            final RelBuilder builder = call.builder();
            final RexSubQuery e = RexUtil.SubQueryFinder.find(project.getProjects());
            assert e != null;
            final RelOptUtil.Logic logic = LogicVisitor.find(
                RelOptUtil.Logic.TRUE_FALSE_UNKNOWN, project.getProjects(), e);
            builder.push(project.getInput());
            final int fieldCount = builder.peek().getRowType().getFieldCount();
            final Set<CorrelationId> variablesSet = RelOptUtil.getVariablesUsed(e.rel);
            final RexNode target = apply(e, variablesSet, logic, builder, 1, fieldCount);
            final ReplaceSubQueryShuttle shuttle = new ReplaceSubQueryShuttle(e, target);
            builder.project(shuttle.apply(project.getProjects()), project.getRowType().getFieldNames());
            call.transformTo(builder.build());
        }
    }
}

