package org.eigenbase.rel.rules;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import org.eigenbase.rel.AggregateCall;
import org.eigenbase.rel.AggregateRel;
import org.eigenbase.rel.AggregateRelBase;
import org.eigenbase.rel.Aggregation;
import org.eigenbase.rel.RelNode;
import org.eigenbase.rel.UnionRel;
import org.eigenbase.rel.metadata.RelMdUtil;
import org.eigenbase.relopt.RelOptCluster;
import org.eigenbase.relopt.RelOptRule;
import org.eigenbase.relopt.RelOptRuleCall;
import org.eigenbase.relopt.RelOptRuleOperand;
import org.eigenbase.relopt.RelOptUtil;
import org.eigenbase.reltype.RelDataType;
import org.eigenbase.reltype.RelDataTypeFactory;
import org.eigenbase.sql.SqlAggFunction;
import org.eigenbase.sql.fun.SqlCountAggFunction;
import org.eigenbase.sql.fun.SqlMinMaxAggFunction;
import org.eigenbase.sql.fun.SqlSumAggFunction;
import org.eigenbase.sql.fun.SqlSumEmptyIsZeroAggFunction;

/* loaded from: input_file:org/eigenbase/rel/rules/PushAggregateThroughUnionRule.class */
/**
 * Planner rule that pushes an {@link AggregateRel} past a {@link UnionRel}
 * whose {@code all} flag is set (UNION ALL).
 *
 * <p>The original aggregate is evaluated on every input of the union. A new
 * aggregate on top of the new union then rolls up the partial results. MIN,
 * MAX and SUM roll up with themselves. COUNT rolls up with
 * {@link SqlSumEmptyIsZeroAggFunction}. The rule does nothing if any call is
 * DISTINCT or uses a function class outside {@link #SUPPORTED_AGGREGATES}.
 */
public class PushAggregateThroughUnionRule extends RelOptRule {
    /** Singleton instance of this rule. */
    public static final PushAggregateThroughUnionRule INSTANCE =
        new PushAggregateThroughUnionRule();

    /**
     * Aggregate function classes whose partial results can be aggregated
     * again. Lookup is by the exact runtime class of the call's aggregation.
     */
    private static final Map<Class<?>, Boolean> SUPPORTED_AGGREGATES =
        new IdentityHashMap<Class<?>, Boolean>();

    static {
        SUPPORTED_AGGREGATES.put(SqlMinMaxAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlCountAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumAggFunction.class, true);
        SUPPORTED_AGGREGATES.put(SqlSumEmptyIsZeroAggFunction.class, true);
    }

    private PushAggregateThroughUnionRule() {
        super(operand(AggregateRel.class, operand(UnionRel.class, any())));
    }

    /**
     * Rewrites {@code Aggregate(Union-all(in_1, ..., in_n))} into
     * {@code Cast(Aggregate'(Union-all(Aggregate(in_1), ..., Aggregate(in_n))))}.
     *
     * <p>Does nothing unless the union is UNION ALL, every aggregate call can
     * be rolled up, and at least one input is not already unique on the
     * group keys.
     */
    @Override // org.eigenbase.relopt.RelOptRule
    public void onMatch(RelOptRuleCall call) {
        AggregateRel aggRel = (AggregateRel) call.rel(0);
        UnionRel unionRel = (UnionRel) call.rel(1);

        if (!unionRel.all) {
            // Only valid for UNION ALL. A distinct UNION removes duplicate
            // rows across inputs, which per-input partial aggregates cannot
            // account for.
            return;
        }

        RelOptCluster cluster = unionRel.getCluster();
        BitSet groupSet = aggRel.getGroupSet();
        int groupCount = groupSet.cardinality();

        List<AggregateCall> transformedAggCalls =
            transformAggCalls(
                aggRel.getCluster().getTypeFactory(),
                groupCount,
                aggRel.getAggCallList());
        if (transformedAggCalls == null) {
            // Some call is DISTINCT or uses a function we cannot roll up.
            return;
        }

        boolean anyTransformed = false;
        List<RelNode> newUnionInputs = new ArrayList<RelNode>();
        for (RelNode input : unionRel.getInputs()) {
            // Every input must be aggregated so that all union inputs share
            // the layout (group columns, then one partial per call) that the
            // top aggregate reads. Passing an input through untouched would
            // give the union mismatched rows. Inputs already unique on the
            // group keys gain nothing from the push-down, so they do not
            // count toward anyTransformed.
            if (!RelMdUtil.areColumnsDefinitelyUnique(input, groupSet)) {
                anyTransformed = true;
            }
            newUnionInputs.add(
                new AggregateRel(
                    cluster, input, groupSet, aggRel.getAggCallList()));
        }

        if (!anyTransformed) {
            // None of the inputs benefits from the push-down. Bailing out here
            // also keeps the rule from firing again on its own output, whose
            // union inputs are aggregates.
            return;
        }

        UnionRel newUnionRel = new UnionRel(cluster, newUnionInputs, true);

        // Union rows are (group columns at ordinals 0..groupCount-1, then the
        // partial results). The top aggregate must therefore group on that
        // leading range, not on the original group set's ordinals.
        BitSet topGroupSet = new BitSet();
        topGroupSet.set(0, groupCount);
        AggregateRel newTopAggRel =
            new AggregateRel(
                cluster, newUnionRel, topGroupSet, transformedAggCalls);

        // COUNT was rewritten to SUM0, whose inferred type may differ from the
        // original COUNT's type, so cast back to the original row type.
        call.transformTo(
            RelOptUtil.createCastRel(
                newTopAggRel, aggRel.getRowType(), false));
    }

    /**
     * Builds the aggregate calls for the top (roll-up) aggregate.
     *
     * @param typeFactory type factory used to infer the SUM0 return type
     * @param groupCount  number of group columns. The partial result of the
     *                    i-th call is read from union field
     *                    {@code groupCount + i}.
     * @param origCalls   aggregate calls of the original aggregate
     * @return one roll-up call per original call, in the same order; or
     *         {@code null} if any call is DISTINCT or its aggregation class
     *         is not in {@link #SUPPORTED_AGGREGATES}
     */
    private List<AggregateCall> transformAggCalls(
        RelDataTypeFactory typeFactory,
        int groupCount,
        List<AggregateCall> origCalls) {
        List<AggregateCall> newCalls = new ArrayList<AggregateCall>();
        int inputOrdinal = groupCount;
        for (AggregateCall origCall : origCalls) {
            if (origCall.isDistinct()
                || !SUPPORTED_AGGREGATES.containsKey(
                    origCall.getAggregation().getClass())) {
                return null;
            }
            Aggregation aggFun;
            RelDataType aggType;
            if (origCall.getAggregation().getName().equals("COUNT")) {
                // Partial counts are added up. SUM0 yields 0 rather than NULL
                // when it sees no rows, which matches COUNT's semantics.
                SqlSumEmptyIsZeroAggFunction sum0 =
                    new SqlSumEmptyIsZeroAggFunction(origCall.getType());
                aggFun = sum0;
                aggType =
                    sum0.inferReturnType(
                        new AggregateRelBase.AggCallBinding(
                            typeFactory,
                            sum0,
                            Collections.singletonList(origCall.getType()),
                            groupCount));
            } else {
                // MIN, MAX and SUM of partial results equal the overall result.
                aggFun = origCall.getAggregation();
                aggType = origCall.getType();
            }
            newCalls.add(
                new AggregateCall(
                    aggFun,
                    origCall.isDistinct(),
                    Collections.singletonList(Integer.valueOf(inputOrdinal)),
                    aggType,
                    origCall.getName()));
            inputOrdinal++;
        }
        return newCalls;
    }
}
