/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.util.ImmutableBitSet;

/**
 * Planner rule that pushes an {@link Aggregate} past a {@link Union} whose
 * {@code all} flag is set (UNION ALL).
 *
 * <p>Each union input that is not already unique on the aggregate's group key
 * receives its own copy of the aggregate. A new aggregate placed on top of the
 * rewritten union then combines the partial results: partial MIN/MAX/SUM/SUM0
 * values are re-aggregated with the same function, and partial COUNTs are
 * added up with SUM0.
 *
 * <p>The rule only fires when every aggregate call is non-distinct and its
 * function class is one of {@code SqlMinMaxAggFunction},
 * {@code SqlCountAggFunction}, {@code SqlSumAggFunction} or
 * {@code SqlSumEmptyIsZeroAggFunction}.
 */
public class AggregateUnionTransposeRule
extends RelOptRule {
    public static final AggregateUnionTransposeRule INSTANCE = new AggregateUnionTransposeRule(LogicalAggregate.class, RelFactories.DEFAULT_AGGREGATE_FACTORY, LogicalUnion.class, RelFactories.DEFAULT_SET_OP_FACTORY);
    /** Factory used for both the pushed-down and the top aggregates. */
    private final RelFactories.AggregateFactory aggregateFactory;
    /** Factory used to rebuild the union over the rewritten inputs. */
    private final RelFactories.SetOpFactory setOpFactory;
    /**
     * Aggregate function classes whose results can be combined from partial
     * results. Used as an identity set keyed by exact class, so subclasses of
     * these functions are not accepted. The Boolean values are unused.
     */
    private static final Map<Class<? extends SqlAggFunction>, Boolean> SUPPORTED_AGGREGATES = new IdentityHashMap<Class<? extends SqlAggFunction>, Boolean>();

    /**
     * Creates a rule that matches an {@code aggregateClass} whose single input
     * is a {@code unionClass}.
     *
     * @param aggregateClass  class of the aggregate to match
     * @param aggregateFactory factory for the aggregates this rule creates
     * @param unionClass      class of the union to match
     * @param setOpFactory    factory for the union this rule creates
     */
    public AggregateUnionTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Union> unionClass, RelFactories.SetOpFactory setOpFactory) {
        super(operand(aggregateClass, operand(unionClass, any())));
        this.aggregateFactory = aggregateFactory;
        this.setOpFactory = setOpFactory;
    }

    /**
     * Rewrites {@code Aggregate(UnionAll(inputs))} into
     * {@code Aggregate'(UnionAll(Aggregate(input) for each input))}.
     *
     * <p>Does nothing in three cases:
     * <ul>
     * <li>the union is not UNION ALL;</li>
     * <li>an aggregate call cannot be split into a partial and a combining
     * phase;</li>
     * <li>every union input is already unique on the group key. In that case
     * the rewrite would rebuild an equivalent union and fire again.</li>
     * </ul>
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggRel = (Aggregate) call.rel(0);
        final Union union = (Union) call.rel(1);
        if (!union.all) {
            // Only valid for UNION ALL. A distinct union removes duplicates
            // across inputs before aggregating. Example: t1 = (5),(5) and
            // t2 = (5),(10); SUM over UNION gives 15, but the sum of
            // per-input SUMs gives 25.
            return;
        }
        final ImmutableBitSet groupSet = aggRel.getGroupSet();
        final int groupCount = groupSet.cardinality();
        // This copy has the row type of every pushed-down aggregate: group
        // keys followed by one column per aggregate call. The combining calls
        // are typed against it.
        final List<AggregateCall> transformedAggCalls = this.transformAggCalls(
                aggRel.copy(aggRel.getTraitSet(), aggRel.getInput(), false, groupSet, null, aggRel.getAggCallList()),
                groupCount, aggRel.getAggCallList());
        if (transformedAggCalls == null) {
            // A DISTINCT call, or a function outside SUPPORTED_AGGREGATES.
            return;
        }
        boolean anyTransformed = false;
        final List<RelNode> newUnionInputs = new ArrayList<RelNode>();
        final RelMetadataQuery mq = RelMetadataQuery.instance();
        for (RelNode input : union.getInputs()) {
            if (RelMdUtil.areColumnsDefinitelyUnique(mq, input, groupSet)) {
                // NOTE(review): this input is added unchanged, so its row type
                // is that of the raw input rather than (group keys, aggregate
                // results) like the pushed-down aggregates. The union is only
                // well-formed when those happen to line up; confirm whether
                // such inputs should be projected or aggregated instead.
                newUnionInputs.add(input);
            } else {
                anyTransformed = true;
                newUnionInputs.add(this.aggregateFactory.createAggregate(input, false, groupSet, null, aggRel.getAggCallList()));
            }
        }
        if (!anyTransformed) {
            return;
        }
        final RelNode newUnion = this.setOpFactory.createSetOp(SqlKind.UNION, newUnionInputs, true);
        // The pushed-down aggregates emit their group keys at columns
        // [0, groupCount), in ascending key order. The top aggregate must
        // group on those columns, not on the original key positions, which
        // index into the old union's row type.
        final ImmutableBitSet topGroupSet = ImmutableBitSet.range(groupCount);
        final ImmutableList<ImmutableBitSet> topGroupSets = remapGroupSets(groupSet, aggRel.getGroupSets());
        final RelNode newTopAggRel = this.aggregateFactory.createAggregate(newUnion, aggRel.indicator, topGroupSet, topGroupSets, transformedAggCalls);
        call.transformTo(newTopAggRel);
    }

    /**
     * Translates grouping sets expressed over the union's columns into
     * grouping sets over the pushed-down aggregates' output. In that output,
     * the {@code i}-th key of {@code groupSet} (in ascending order) sits at
     * column {@code i}.
     *
     * @param groupSet  the original aggregate's group key
     * @param groupSets the original aggregate's grouping sets, each assumed to
     *                  be a subset of {@code groupSet}; may be null
     * @return the remapped grouping sets in the same order, or null if
     *         {@code groupSets} is null
     */
    private static ImmutableList<ImmutableBitSet> remapGroupSets(ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets) {
        if (groupSets == null) {
            return null;
        }
        // keys.get(i) is the i-th group key in ascending order, so
        // keys.indexOf(k) is the output column of key k.
        final List<Integer> keys = new ArrayList<Integer>();
        for (int key : groupSet) {
            keys.add(key);
        }
        final ImmutableList.Builder<ImmutableBitSet> remapped = ImmutableList.builder();
        for (ImmutableBitSet groupingSet : groupSets) {
            final ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
            for (int key : groupingSet) {
                builder.set(keys.indexOf(key));
            }
            remapped.add(builder.build());
        }
        return remapped.build();
    }

    /**
     * Builds the aggregate calls for the top aggregate. These calls combine
     * the partial results produced by the aggregates pushed below the union.
     *
     * @param input      relational expression whose row type is group keys
     *                   followed by one column per call; used to infer the
     *                   result type when it is not copied from the original
     * @param groupCount number of group keys
     * @param origCalls  aggregate calls of the original aggregate
     * @return one combining call per original call, in the same order. Each
     *         reads column {@code groupCount + i}. Returns null if any call
     *         is DISTINCT or uses an unsupported function class.
     */
    private List<AggregateCall> transformAggCalls(RelNode input, int groupCount, List<AggregateCall> origCalls) {
        final List<AggregateCall> newCalls = Lists.newArrayList();
        for (Ord<AggregateCall> ord : Ord.zip(origCalls)) {
            final AggregateCall origCall = ord.e;
            if (origCall.isDistinct() || !SUPPORTED_AGGREGATES.containsKey(origCall.getAggregation().getClass())) {
                return null;
            }
            final SqlAggFunction aggFun;
            final RelDataType aggType;
            // Match on kind rather than on the SqlStdOperatorTable.COUNT
            // singleton: any supported SqlCountAggFunction instance yields
            // partial counts, which must be summed, not counted again.
            if (origCall.getAggregation().getKind() == SqlKind.COUNT) {
                aggFun = SqlStdOperatorTable.SUM0;
                // Let AggregateCall infer SUM0's type instead of reusing
                // COUNT's type.
                aggType = null;
            } else {
                // MIN/MAX/SUM/SUM0 of partial results equals the function
                // applied to all rows.
                aggFun = origCall.getAggregation();
                aggType = origCall.getType();
            }
            newCalls.add(AggregateCall.create(aggFun, origCall.isDistinct(), ImmutableList.of(groupCount + ord.i), groupCount, input, aggType, origCall.getName()));
        }
        return newCalls;
    }

    static {
        SUPPORTED_AGGREGATES.put(SqlMinMaxAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlCountAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumEmptyIsZeroAggFunction.class, true);
    }
}

