package io.confluent.ksql.planner.plan;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.function.udaf.KudafAggregator;
import io.confluent.ksql.function.udaf.KudafInitializer;
import io.confluent.ksql.metastore.MetastoreUtil;
import io.confluent.ksql.parser.tree.Expression;
import io.confluent.ksql.parser.tree.FunctionCall;
import io.confluent.ksql.parser.tree.WindowExpression;
import io.confluent.ksql.serde.KsqlTopicSerDe;
import io.confluent.ksql.structured.SchemaKGroupedStream;
import io.confluent.ksql.structured.SchemaKStream;
import io.confluent.ksql.structured.SchemaKTable;
import io.confluent.ksql.util.AggregateExpressionRewriter;
import io.confluent.ksql.util.KafkaTopicClient;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.Pair;
import io.confluent.ksql.util.SchemaUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.connect.data.Field;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.streams.StreamsBuilder;

/* loaded from: input_file:io/confluent/ksql/planner/plan/AggregateNode.class */
/**
 * Logical plan node for an aggregation query.
 *
 * <p>When built, it does the following, in order:
 * <ol>
 *   <li>projects the source stream down to the distinct required columns and aggregate-function
 *       arguments;</li>
 *   <li>groups the result by the GROUP BY expressions;</li>
 *   <li>aggregates with the configured aggregate functions, optionally windowed;</li>
 *   <li>applies the optional HAVING filter;</li>
 *   <li>finishes with the final select expressions.</li>
 * </ol>
 */
public class AggregateNode extends PlanNode {
    private final PlanNode source;
    private final Schema schema;
    private final List<Expression> groupByExpressions;
    private final WindowExpression windowExpression;
    private final List<Expression> aggregateFunctionArguments;
    private final List<FunctionCall> functionList;
    private final List<Expression> requiredColumnList;
    private final List<Expression> finalSelectExpressions;
    private final Expression havingExpressions;

    /**
     * Creates an aggregate node; all arguments are stored as given (no copying or validation).
     *
     * @param id node identifier passed to {@link PlanNode}
     * @param source upstream plan node whose stream is aggregated
     * @param schema output schema; its field names label {@code finalSelectExpressions}
     *     positionally, so both must have the same size
     * @param groupByExpressions expressions the projected stream is grouped by
     * @param windowExpression window passed to the aggregation (NOTE(review): presumably
     *     {@code null} for non-windowed aggregates — confirm)
     * @param aggregateFunctionArguments argument expressions of the aggregate functions
     * @param functionList aggregate function calls, in output order
     * @param requiredColumnList non-aggregate columns carried into the aggregate value
     * @param finalSelectExpressions expressions selected from the aggregated table
     * @param havingExpressions HAVING predicate, or {@code null} when there is none
     */
    @JsonCreator
    public AggregateNode(
            @JsonProperty("id") PlanNodeId id,
            @JsonProperty("source") PlanNode source,
            @JsonProperty("schema") Schema schema,
            @JsonProperty("groupby") List<Expression> groupByExpressions,
            @JsonProperty("window") WindowExpression windowExpression,
            @JsonProperty("aggregateFunctionArguments") List<Expression> aggregateFunctionArguments,
            @JsonProperty("functionList") List<FunctionCall> functionList,
            @JsonProperty("requiredColumnList") List<Expression> requiredColumnList,
            @JsonProperty("finalSelectExpressions") List<Expression> finalSelectExpressions,
            @JsonProperty("havingExpressions") Expression havingExpressions) {
        super(id);
        this.source = source;
        this.schema = schema;
        this.groupByExpressions = groupByExpressions;
        this.windowExpression = windowExpression;
        this.aggregateFunctionArguments = aggregateFunctionArguments;
        this.functionList = functionList;
        this.requiredColumnList = requiredColumnList;
        this.finalSelectExpressions = finalSelectExpressions;
        this.havingExpressions = havingExpressions;
    }

    @Override
    public Schema getSchema() {
        return schema;
    }

    /** Always returns {@code null}: this node does not expose a key field. */
    @Override
    public Field getKeyField() {
        return null;
    }

    /** Returns an immutable single-element list holding the source node. */
    @Override
    public List<PlanNode> getSources() {
        return ImmutableList.of(source);
    }

    public PlanNode getSource() {
        return source;
    }

    public List<Expression> getGroupByExpressions() {
        return groupByExpressions;
    }

    public WindowExpression getWindowExpression() {
        return windowExpression;
    }

    public List<Expression> getAggregateFunctionArguments() {
        return aggregateFunctionArguments;
    }

    public List<FunctionCall> getFunctionList() {
        return functionList;
    }

    public List<Expression> getRequiredColumnList() {
        return requiredColumnList;
    }

    /**
     * Pairs each output-schema field name with the final select expression at the same position.
     *
     * @throws KsqlException if the number of final select expressions differs from the number of
     *     fields in the output schema
     */
    private List<Pair<String, Expression>> getFinalSelectExpressions() {
        final List<Field> fields = schema.fields();
        if (finalSelectExpressions.size() != fields.size()) {
            throw new KsqlException("Incompatible aggregate schema, field count must match, selected field count:" + finalSelectExpressions.size() + " schema field count:" + fields.size());
        }
        final List<Pair<String, Expression>> selectExpressions = new ArrayList<>(fields.size());
        for (int i = 0; i < fields.size(); i++) {
            selectExpressions.add(new Pair<>(fields.get(i).name(), finalSelectExpressions.get(i)));
        }
        return selectExpressions;
    }

    public Expression getHavingExpressions() {
        return havingExpressions;
    }

    @Override
    public <C, R> R accept(PlanVisitor<C, R> visitor, C context) {
        return visitor.visitAggregate(this, context);
    }

    /**
     * Builds the source stream, then projects, groups, aggregates, filters (HAVING) and selects.
     *
     * @return the stream produced by applying the final select expressions to the aggregated table
     */
    @Override
    public SchemaKStream buildStream(StreamsBuilder builder, KsqlConfig ksqlConfig, KafkaTopicClient kafkaTopicClient, MetastoreUtil metastoreUtil, FunctionRegistry functionRegistry, Map<String, Object> props, SchemaRegistryClient schemaRegistryClient) {
        final StructuredDataSourceNode streamSourceNode = getTheSourceNode();
        final SchemaKStream sourceSchemaKStream = getSource().buildStream(builder, ksqlConfig, kafkaTopicClient, metastoreUtil, functionRegistry, props, schemaRegistryClient);

        // Project the source down to the distinct expressions the aggregation reads.
        // Required columns come first, so they occupy the leading fields of the expanded schema.
        // expressionNames maps each expression's text to its position in that projection.
        final List<Pair<String, Expression>> aggArgExpansionList = new ArrayList<>();
        final Map<String, Integer> expressionNames = new HashMap<>();
        collectAggregateArgExpressions(getRequiredColumnList(), aggArgExpansionList, expressionNames);
        collectAggregateArgExpressions(getAggregateFunctionArguments(), aggArgExpansionList, expressionNames);
        final SchemaKStream aggregateArgExpanded = sourceSchemaKStream.select(aggArgExpansionList);
        final Schema expandedSchema = aggregateArgExpanded.getSchema();

        final KsqlTopicSerDe ksqlTopicSerDe = streamSourceNode.getStructuredDataSource().getKsqlTopic().getKsqlTopicSerDe();
        final SchemaKGroupedStream groupedStream = aggregateArgExpanded.groupBy(
                Serdes.String(),
                ksqlTopicSerDe.getGenericRowSerde(expandedSchema, ksqlConfig, true, schemaRegistryClient),
                getGroupByExpressions());

        // Aggregate value layout: the required (non-aggregate) columns, then one slot per function.
        final Map<Integer, Integer> aggValueToValueColumnMap = createAggregateValueToValueColumnMap(expandedSchema);
        final Schema aggregateSchema = buildAggregateSchema(expandedSchema, functionRegistry);
        final Serde<GenericRow> aggValueSerde = ksqlTopicSerDe.getGenericRowSerde(aggregateSchema, ksqlConfig, true, schemaRegistryClient);

        final int nonAggColumnCount = aggValueToValueColumnMap.size();
        final KudafInitializer initializer = new KudafInitializer(nonAggColumnCount);
        final Map<Integer, KsqlAggregateFunction> aggValToFunctionMap = createAggValToFunctionMap(expressionNames, expandedSchema, initializer, nonAggColumnCount, functionRegistry);
        final SchemaKTable aggregated = groupedStream.aggregate(
                initializer,
                new KudafAggregator(aggValToFunctionMap, aggValueToValueColumnMap),
                getWindowExpression(),
                aggValueSerde);

        // Re-wrap the aggregated table so it carries the aggregate schema and the AGGREGATE type.
        SchemaKTable result = new SchemaKTable(aggregateSchema, aggregated.getKtable(), aggregated.getKeyField(), aggregated.getSourceSchemaKStreams(), aggregated.isWindowed(), SchemaKStream.Type.AGGREGATE, functionRegistry, schemaRegistryClient);
        if (getHavingExpressions() != null) {
            result = result.filter(getHavingExpressions());
        }
        return result.select(getFinalSelectExpressions());
    }

    /**
     * Maps each aggregate-value position {@code i} to the expanded-schema index of the
     * {@code i}-th required column.
     *
     * @throws KsqlException if a required column resolves to an index outside the expanded schema
     */
    private Map<Integer, Integer> createAggregateValueToValueColumnMap(final Schema expandedSchema) {
        final int fieldCount = expandedSchema.fields().size();
        final Map<Integer, Integer> aggValueToValueColumn = new HashMap<>();
        int aggIndex = 0;
        for (final Expression requiredColumn : getRequiredColumnList()) {
            final String columnName = requiredColumn.toString();
            final int indexInSchema = SchemaUtil.getIndexInSchema(columnName, expandedSchema);
            // NOTE(review): unclear whether getIndexInSchema can return an out-of-range value
            // (e.g. -1); guard explicitly so a bad index cannot leak into the aggregator.
            if (indexInSchema < 0 || indexInSchema >= fieldCount) {
                throw new KsqlException("Required column " + columnName + " not found in aggregate input schema, index:" + indexInSchema);
            }
            aggValueToValueColumn.put(aggIndex++, indexInSchema);
        }
        return aggValueToValueColumn;
    }

    /**
     * Appends each expression not already present (by its {@code toString()} text) to
     * {@code expansionList}, and records its position there in {@code expressionNames}.
     * Duplicates, including duplicates within {@code expressions}, are added only once.
     */
    private void collectAggregateArgExpressions(final List<Expression> expressions, final List<Pair<String, Expression>> expansionList, final Map<String, Integer> expressionNames) {
        for (final Expression expression : expressions) {
            final String name = expression.toString();
            if (!expressionNames.containsKey(name)) {
                expressionNames.put(name, expansionList.size());
                expansionList.add(new Pair<>(name, expression));
            }
        }
    }

    /**
     * Instantiates one aggregate function per function call and assigns them consecutive
     * aggregate-value positions starting at {@code initialUdafIndex}. Each function's initial
     * value supplier is registered with {@code initializer}.
     *
     * @throws KsqlException wrapping any failure while resolving or instantiating a function
     */
    private Map<Integer, KsqlAggregateFunction> createAggValToFunctionMap(final Map<String, Integer> expressionNames, final Schema expandedSchema, final KudafInitializer initializer, final int initialUdafIndex, final FunctionRegistry functionRegistry) {
        try {
            int udafIndex = initialUdafIndex;
            final Map<Integer, KsqlAggregateFunction> aggValToFunction = new HashMap<>();
            for (final FunctionCall functionCall : getFunctionList()) {
                final KsqlAggregateFunction aggregateFunction = functionRegistry
                        .getAggregateFunction(functionCall.getName().toString(), functionCall.getArguments(), expandedSchema)
                        .getInstance(expressionNames, functionCall.getArguments());
                aggValToFunction.put(udafIndex++, aggregateFunction);
                initializer.addAggregateIntializer(aggregateFunction.getInitialValueSupplier());
            }
            return aggValToFunction;
        } catch (final Exception e) {
            throw new KsqlException(String.format("Failed to create aggregate val to function map. expressionNames:%s", expressionNames), e);
        }
    }

    /**
     * Builds the schema of the aggregate value. It consists of the leading
     * {@code getRequiredColumnList().size()} fields of {@code schema}, followed by one field per
     * aggregate function. Those fields are named
     * {@code AGGREGATE_FUNCTION_VARIABLE_PREFIX + i} and typed by each function's return type.
     */
    private Schema buildAggregateSchema(final Schema schema, final FunctionRegistry functionRegistry) {
        final SchemaBuilder schemaBuilder = SchemaBuilder.struct();
        final List<Field> fields = schema.fields();
        // Relies on required columns being projected first (see buildStream).
        // NOTE(review): this assumes requiredColumnList has no duplicate expressions — confirm.
        for (int i = 0; i < getRequiredColumnList().size(); i++) {
            final Field field = fields.get(i);
            schemaBuilder.field(field.name(), field.schema());
        }
        final List<FunctionCall> functions = getFunctionList();
        for (int i = 0; i < functions.size(); i++) {
            final FunctionCall functionCall = functions.get(i);
            final KsqlAggregateFunction aggregateFunction = functionRegistry.getAggregateFunction(functionCall.getName().getSuffix(), functionCall.getArguments(), schema);
            schemaBuilder.field(AggregateExpressionRewriter.AGGREGATE_FUNCTION_VARIABLE_PREFIX + i, aggregateFunction.getReturnType());
        }
        return schemaBuilder.build();
    }
}
