/*
 * Decompiled with CFR 0.152.
 */
package org.eigenbase.rel.rules;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.List;
import org.eigenbase.rel.AggregateCall;
import org.eigenbase.rel.AggregateRel;
import org.eigenbase.rel.Aggregation;
import org.eigenbase.rel.RelNode;
import org.eigenbase.rel.UnionRel;
import org.eigenbase.rel.metadata.RelMdUtil;
import org.eigenbase.relopt.RelOptCluster;
import org.eigenbase.relopt.RelOptRule;
import org.eigenbase.relopt.RelOptRuleCall;
import org.eigenbase.relopt.RelOptRuleOperand;
import org.eigenbase.relopt.RelOptUtil;
import org.eigenbase.reltype.RelDataType;
import org.eigenbase.reltype.RelDataTypeFactory;
import org.eigenbase.sql.fun.SqlSumAggFunction;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Planner rule that pushes an {@link AggregateRel} below a {@link UnionRel}.
 *
 * <p>The rule matches an {@code AggregateRel} whose input is a
 * {@code UnionRel} (UNION ALL only) and rewrites
 *
 * <pre>
 * Aggregate[groupSet=G, calls=C]
 *   Union ALL(in_1, ..., in_n)
 * </pre>
 *
 * into
 *
 * <pre>
 * Cast(
 *   Aggregate[groupSet={0..k-1}, calls=C']
 *     Union ALL(Aggregate[G, C](in_1), ..., Aggregate[G, C](in_n)))
 * </pre>
 *
 * <p>Here {@code k} is the number of grouping columns. {@code C'} combines the
 * partial results of the lower aggregates; for example, a lower {@code COUNT}
 * is combined by an upper {@code SUM}.
 *
 * <p>An input that is already unique on the grouping columns and whose row
 * consists of exactly the grouping columns (in order, with no aggregate calls)
 * is passed through without a lower aggregate. The rule only fires if at
 * least one input is not already unique on the grouping columns. Assuming the
 * metadata reports the new lower aggregates as unique on their grouping
 * columns, this keeps the rule from re-firing on its own output.
 */
public class PushAggregateThroughUnionRule
extends RelOptRule {
    /** Singleton instance of this rule. */
    public static final PushAggregateThroughUnionRule INSTANCE = new PushAggregateThroughUnionRule();

    private PushAggregateThroughUnionRule() {
        super(operand(AggregateRel.class, operand(UnionRel.class, any()), new RelOptRuleOperand[0]));
    }

    /**
     * Rewrites the matched {@code AggregateRel(UnionRel)} pair as described in
     * the class documentation.
     *
     * <p>It does nothing if any of these hold:
     * <ul>
     * <li>the union is not UNION ALL;</li>
     * <li>some aggregate call cannot be split into two phases;</li>
     * <li>every union input is already unique on the grouping columns.</li>
     * </ul>
     *
     * @param call rule call; {@code rel(0)} is the aggregate and
     *     {@code rel(1)} is the union
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        AggregateRel aggRel = (AggregateRel)call.rel(0);
        UnionRel unionRel = (UnionRel)call.rel(1);

        // Only valid for UNION ALL. Take t1(i) = {(5), (5)} and t2(i) = {(5), (10)}.
        // SUM(i) over "t1 UNION t2" is 15. Pushing SUM into each branch gives
        // {10, 15}, whose SUM is 25, which is wrong.
        if (!unionRel.all) {
            return;
        }

        int groupCount = aggRel.getGroupCount();
        BitSet groupSet = aggRel.getGroupSet();
        List<AggregateCall> aggCalls = aggRel.getAggCallList();

        List<AggregateCall> topAggCalls = transformAggCalls(aggRel.getCluster().getTypeFactory(), groupCount, aggCalls);
        if (topAggCalls == null) {
            // Some call (DISTINCT, AVG, ...) cannot be computed in two phases.
            return;
        }

        RelOptCluster cluster = unionRel.getCluster();
        List<RelNode> newUnionInputs = new ArrayList<RelNode>();
        boolean anyTransformed = false;
        for (RelNode input : unionRel.getInputs()) {
            boolean alreadyUnique = RelMdUtil.areColumnsDefinitelyUnique(input, groupSet);
            if (!alreadyUnique) {
                anyTransformed = true;
            }
            if (alreadyUnique && hasLowerAggregateShape(input, groupSet, aggCalls)) {
                newUnionInputs.add(input);
            } else {
                // A unique input still needs a lower aggregate unless its row already
                // has the lower aggregate's shape. Otherwise the union's inputs would
                // have different column layouts. Aggregating a unique input is
                // correct, merely redundant.
                newUnionInputs.add(new AggregateRel(cluster, input, groupSet, aggCalls));
            }
        }
        if (!anyTransformed) {
            // No input benefits from the pushdown; bail out.
            return;
        }

        // Each lower aggregate emits its grouping columns first, at ordinals
        // 0..groupCount-1. The upper aggregate must group on those ordinals.
        // The original group set indexes the old union row type, not this one.
        BitSet topGroupSet = new BitSet();
        topGroupSet.set(0, groupCount);

        UnionRel newUnionRel = new UnionRel(cluster, newUnionInputs, true);
        AggregateRel newTopAggRel = new AggregateRel(cluster, newUnionRel, topGroupSet, topAggCalls);

        // COUNT calls were rewritten to SUM with a nullable result type (see
        // transformAggCalls). Cast back to the original aggregate's row type.
        RelNode castRel = RelOptUtil.createCastRel(newTopAggRel, aggRel.getRowType(), false);
        call.transformTo(castRel);
    }

    /**
     * Returns whether {@code input} can be placed directly in the new union
     * instead of {@code Aggregate[groupSet, aggCalls](input)}.
     *
     * <p>That requires no aggregate calls and a group set covering exactly
     * columns {@code 0..fieldCount-1} of the input. The lower aggregate would
     * then project exactly the input's own columns.
     *
     * @param input a union input
     * @param groupSet the original aggregate's group set (not modified)
     * @param aggCalls the original aggregate's calls
     * @return true if the input already has the lower aggregate's column layout
     */
    private static boolean hasLowerAggregateShape(RelNode input, BitSet groupSet, List<AggregateCall> aggCalls) {
        int fieldCount = input.getRowType().getFieldCount();
        return aggCalls.isEmpty()
            && groupSet.cardinality() == fieldCount
            && groupSet.nextClearBit(0) == fieldCount;
    }

    /**
     * Returns whether an aggregate function with the given name can be split
     * into a lower phase (per union input) and an upper phase (over the union).
     *
     * <p>This is a conservative whitelist. Any other function is rejected
     * rather than risk an incorrect rewrite.
     *
     * @param aggName aggregate function name
     * @return true for COUNT, SUM, MIN and MAX
     */
    private static boolean isTwoPhaseAggregate(String aggName) {
        return "COUNT".equals(aggName)
            || "SUM".equals(aggName)
            || "MIN".equals(aggName)
            || "MAX".equals(aggName);
    }

    /**
     * Builds the aggregate calls of the upper aggregate.
     *
     * <p>Each call reads the column produced by the matching call of the lower
     * aggregates. COUNT becomes SUM over a nullable copy of the COUNT's type.
     * The other supported calls reuse their original aggregation and type.
     *
     * @param typeFactory type factory used to derive the nullable SUM type
     * @param nGroupCols number of grouping columns; the lower aggregates'
     *     call outputs start at this ordinal
     * @param origCalls the original aggregate's calls, in output order
     * @return the upper calls in the same order, or {@code null} if any call
     *     is DISTINCT or uses an aggregate that cannot be split in two phases
     */
    private List<AggregateCall> transformAggCalls(RelDataTypeFactory typeFactory, int nGroupCols, List<AggregateCall> origCalls) {
        List<AggregateCall> newCalls = new ArrayList<AggregateCall>();
        int iInput = nGroupCols;
        for (AggregateCall origCall : origCalls) {
            if (origCall.isDistinct()) {
                return null;
            }
            String aggName = origCall.getAggregation().getName();
            if (!isTwoPhaseAggregate(aggName)) {
                return null;
            }
            Aggregation aggFun;
            RelDataType aggType;
            if (aggName.equals("COUNT")) {
                // The upper phase adds up the per-input counts.
                aggType = typeFactory.createTypeWithNullability(origCall.getType(), true);
                aggFun = new SqlSumAggFunction(aggType);
            } else {
                aggFun = origCall.getAggregation();
                aggType = origCall.getType();
            }
            newCalls.add(new AggregateCall(aggFun, false, Collections.singletonList(iInput), aggType, origCall.getName()));
            ++iInput;
        }
        return newCalls;
    }
}

