/*
 * Decompiled with CFR 0.152.
 */
package org.eigenbase.rel.rules;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.hydromatic.optiq.util.BitSets;
import org.eigenbase.rel.AggregateCall;
import org.eigenbase.rel.AggregateRel;
import org.eigenbase.rel.CalcRel;
import org.eigenbase.rel.ProjectRel;
import org.eigenbase.rel.RelNode;
import org.eigenbase.relopt.RelOptRule;
import org.eigenbase.relopt.RelOptRuleCall;
import org.eigenbase.relopt.RelOptRuleOperand;
import org.eigenbase.reltype.RelDataType;
import org.eigenbase.reltype.RelDataTypeField;
import org.eigenbase.rex.RexBuilder;
import org.eigenbase.rex.RexLocalRef;
import org.eigenbase.rex.RexNode;
import org.eigenbase.rex.RexProgram;
import org.eigenbase.util.IntList;
import org.eigenbase.util.Pair;
import org.eigenbase.util.Permutation;
import org.eigenbase.util.mapping.Mapping;
import org.eigenbase.util.mapping.MappingType;

public class PullConstantsThroughAggregatesRule
extends RelOptRule {
    public static final PullConstantsThroughAggregatesRule INSTANCE = new PullConstantsThroughAggregatesRule();

    private PullConstantsThroughAggregatesRule() {
        super(PullConstantsThroughAggregatesRule.operand(AggregateRel.class, PullConstantsThroughAggregatesRule.operand(ProjectRel.class, PullConstantsThroughAggregatesRule.any()), new RelOptRuleOperand[0]));
    }

    public void onMatch(RelOptRuleCall call) {
        AggregateRel newAggregate;
        AggregateRel aggregate = (AggregateRel)call.rel(0);
        ProjectRel child = (ProjectRel)call.rel(1);
        int groupCount = aggregate.getGroupCount();
        if (groupCount == 1) {
            return;
        }
        RexProgram program = RexProgram.create(child.getChild().getRowType(), child.getProjects(), null, child.getRowType(), child.getCluster().getRexBuilder());
        RelDataType childRowType = child.getRowType();
        IntList constantList = new IntList();
        HashMap<Integer, RexNode> constants = new HashMap<Integer, RexNode>();
        for (int i : BitSets.toIter(aggregate.getGroupSet())) {
            RexLocalRef ref = program.getProjectList().get(i);
            if (!program.isConstant(ref)) continue;
            constantList.add(i);
            constants.put(i, program.gatherExpr(ref));
        }
        if (constantList.size() == 0) {
            return;
        }
        if (groupCount == constantList.size()) {
            constantList.remove(0);
        }
        int newGroupCount = groupCount - constantList.size();
        if ((Integer)constantList.get(0) == newGroupCount) {
            ArrayList<AggregateCall> newAggCalls = new ArrayList<AggregateCall>();
            for (AggregateCall aggCall : aggregate.getAggCallList()) {
                newAggCalls.add(aggCall.adaptTo(child, aggCall.getArgList(), groupCount, newGroupCount));
            }
            newAggregate = new AggregateRel(aggregate.getCluster(), child, BitSets.range(newGroupCount), newAggCalls);
        } else {
            Permutation mapping = new Permutation(childRowType.getFieldCount());
            mapping.identity();
            int i = 0;
            int groupOrdinal = 0;
            int constOrdinal = newGroupCount;
            while (i < groupCount) {
                if (i >= groupCount) {
                    mapping.set(i, i);
                } else if (constants.containsKey(i)) {
                    mapping.set(i, constOrdinal++);
                } else {
                    mapping.set(i, groupOrdinal++);
                }
                ++i;
            }
            RelNode project = PullConstantsThroughAggregatesRule.createProjection(mapping, child);
            ArrayList<AggregateCall> newAggCalls = new ArrayList<AggregateCall>();
            for (AggregateCall aggCall : aggregate.getAggCallList()) {
                int argCount = aggCall.getArgList().size();
                ArrayList<Integer> args = new ArrayList<Integer>(argCount);
                int j = 0;
                while (j < argCount) {
                    Integer arg = aggCall.getArgList().get(j);
                    args.add(mapping.getTarget(arg));
                    ++j;
                }
                newAggCalls.add(aggCall.adaptTo(project, args, groupCount, newGroupCount));
            }
            newAggregate = new AggregateRel(aggregate.getCluster(), project, BitSets.range(newGroupCount), newAggCalls);
        }
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        ArrayList<Pair<RexNode, String>> projects = new ArrayList<Pair<RexNode, String>>();
        int source = 0;
        for (RelDataTypeField field : aggregate.getRowType().getFieldList()) {
            RexNode expr;
            int i = field.getIndex();
            if (i >= groupCount) {
                expr = rexBuilder.makeInputRef(newAggregate, i - constantList.size());
            } else if (constantList.contains(i)) {
                expr = (RexNode)constants.get(i);
            } else {
                expr = rexBuilder.makeInputRef(newAggregate, source);
                ++source;
            }
            projects.add(Pair.of(expr, field.getName()));
        }
        RelNode inverseProject = CalcRel.createProject((RelNode)newAggregate, projects, false);
        call.transformTo(inverseProject);
    }

    private static RelNode createProjection(Mapping mapping, RelNode child) {
        assert (mapping.getMappingType().isA(MappingType.INVERSE_SURJECTION));
        RelDataType childRowType = child.getRowType();
        assert (mapping.getSourceCount() == childRowType.getFieldCount());
        ArrayList<Pair<RexNode, String>> projects = new ArrayList<Pair<RexNode, String>>();
        int target = 0;
        while (target < mapping.getTargetCount()) {
            int source = mapping.getSource(target);
            RexBuilder rexBuilder = child.getCluster().getRexBuilder();
            projects.add(Pair.of(rexBuilder.makeInputRef(child, source), childRowType.getFieldList().get(source).getName()));
            ++target;
        }
        return CalcRel.createProject(child, projects, false);
    }
}

