package org.apache.drill.exec.physical.impl.agg;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.apache.drill.common.expression.ExpressionPosition;
import org.apache.drill.common.expression.FieldReference;
import org.apache.drill.common.expression.FunctionCall;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.logical.data.NamedExpression;
import org.apache.drill.common.types.TypeProtos;
import org.apache.drill.exec.physical.base.PhysicalOperator;
import org.apache.drill.exec.physical.config.HashAggregate;
import org.apache.drill.exec.physical.impl.MockRecordBatch;
import org.apache.drill.exec.physical.rowSet.RowSet;
import org.apache.drill.exec.physical.rowSet.RowSetBuilder;
import org.apache.drill.exec.planner.physical.AggPrelBase;
import org.apache.drill.exec.record.metadata.SchemaBuilder;
import org.apache.drill.exec.record.metadata.TupleMetadata;
import org.apache.drill.test.PhysicalOpUnitTestBase;
import org.junit.Before;
import org.junit.Test;

/**
 * Operator-level tests for {@link HashAggregate}: rows are grouped by first name and last name
 * and {@code sum(stuff)} is computed for every group.
 *
 * <p>Each test generates the same logical input (every first/last name combination with
 * {@code numRows} rows), splits it into upstream batches of a configurable maximum size, and
 * checks the combined, unordered output for a given aggregation phase.
 */
public class TestHashAggBatch extends PhysicalOpUnitTestBase {
    public static final String FIRST_NAME_COL = "firstname";
    public static final String LAST_NAME_COL = "lastname";
    public static final String STUFF_COL = "stuff";
    public static final List<String> FIRST_NAMES = ImmutableList.of("Strawberry", "Banana", "Mango", "Grape");
    public static final List<String> LAST_NAMES = ImmutableList.of("Red", "Green", "Blue", "Purple");
    public static final String TOTAL_STUFF_COL = "totalstuff";

    /** Expected output schema: the two REQUIRED VARCHAR group-by keys plus an OPTIONAL BIGINT sum. */
    public static final TupleMetadata INT_OUTPUT_SCHEMA = new SchemaBuilder()
        .add(FIRST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
        .add(LAST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
        .add(TOTAL_STUFF_COL, TypeProtos.MinorType.BIGINT, TypeProtos.DataMode.OPTIONAL)
        .buildSchema();

    /** Number of input rows generated for each (first name, last name) group. */
    private static final int ROWS_PER_GROUP = 100;

    /** Upstream batch size used by the multi-batch tests. */
    private static final int SMALL_BATCH_SIZE = 100;

    /**
     * Runs before every test in this class (despite its name): sets the
     * {@code exec.hashagg.num_partitions} local option to 1.
     */
    @Before
    public void setupSimpleSingleBatchSumTestPhase1of2() {
        this.operatorFixture.getOptionManager().setLocalOption("exec.hashagg.num_partitions", 1L);
    }

    @Test
    public void simpleSingleBatchSumTestPhase1of2() throws Exception {
        batchSumTest(ROWS_PER_GROUP, Integer.MAX_VALUE, AggPrelBase.OperatorPhase.PHASE_1of2);
    }

    @Test
    public void simpleMultiBatchSumTestPhase1of2() throws Exception {
        batchSumTest(ROWS_PER_GROUP, SMALL_BATCH_SIZE, AggPrelBase.OperatorPhase.PHASE_1of2);
    }

    @Test
    public void simpleSingleBatchSumTestPhase1of1() throws Exception {
        batchSumTest(ROWS_PER_GROUP, Integer.MAX_VALUE, AggPrelBase.OperatorPhase.PHASE_1of1);
    }

    @Test
    public void simpleMultiBatchSumTestPhase1of1() throws Exception {
        batchSumTest(ROWS_PER_GROUP, SMALL_BATCH_SIZE, AggPrelBase.OperatorPhase.PHASE_1of1);
    }

    @Test
    public void simpleSingleBatchSumTestPhase2of2() throws Exception {
        batchSumTest(ROWS_PER_GROUP, Integer.MAX_VALUE, AggPrelBase.OperatorPhase.PHASE_2of2);
    }

    @Test
    public void simpleMultiBatchSumTestPhase2of2() throws Exception {
        batchSumTest(ROWS_PER_GROUP, SMALL_BATCH_SIZE, AggPrelBase.OperatorPhase.PHASE_2of2);
    }

    /**
     * Runs a {@code sum(stuff)} group-by over generated INT input and verifies the output.
     *
     * @param numRows number of input rows generated for each (first name, last name) pair
     * @param maxBatchSize maximum number of rows per upstream input batch
     * @param operatorPhase aggregation phase under test
     */
    private void batchSumTest(int numRows, int maxBatchSize, AggPrelBase.OperatorPhase operatorPhase) throws Exception {
        HashAggregate hashAggregate = createHashAggPhysicalOperator(operatorPhase);
        List<RowSet> inputRowSets =
            buildInputRowSets(TypeProtos.MinorType.INT, TypeProtos.DataMode.REQUIRED, numRows, maxBatchSize);

        MockRecordBatch.Builder batchBuilder = new MockRecordBatch.Builder();
        for (RowSet rowSet : inputRowSets) {
            batchBuilder.sendData(rowSet);
        }
        MockRecordBatch upstream = batchBuilder.build(this.fragContext);

        opTestBuilder()
            .physicalOperator(hashAggregate)
            .combineOutputBatches()
            .unordered()
            .addUpstreamBatch(upstream)
            .addExpectedResult(buildIntExpectedRowSet(numRows))
            .go();
    }

    /**
     * Builds a hash aggregate that groups by first and last name and writes
     * {@code sum(stuff)} into {@link #TOTAL_STUFF_COL}.
     *
     * @param operatorPhase aggregation phase for the operator
     * @return the configured operator, with no child and a cardinality of 0
     */
    private HashAggregate createHashAggPhysicalOperator(AggPrelBase.OperatorPhase operatorPhase) {
        List<NamedExpression> groupByExprs = Lists.newArrayList(
            new NamedExpression(SchemaPath.getSimplePath(FIRST_NAME_COL), new FieldReference(FIRST_NAME_COL)),
            new NamedExpression(SchemaPath.getSimplePath(LAST_NAME_COL), new FieldReference(LAST_NAME_COL)));
        List<NamedExpression> aggExprs = Lists.newArrayList(
            new NamedExpression(
                new FunctionCall("sum", ImmutableList.of(SchemaPath.getSimplePath(STUFF_COL)),
                    new ExpressionPosition((String) null, 0)),
                new FieldReference(TOTAL_STUFF_COL)));
        return new HashAggregate((PhysicalOperator) null, operatorPhase, groupByExprs, aggExprs, 0.0f);
    }

    /**
     * Builds the input schema: two REQUIRED VARCHAR name columns followed by the
     * {@link #STUFF_COL} value column of the requested type and mode.
     */
    private TupleMetadata buildInputSchema(TypeProtos.MinorType minorType, TypeProtos.DataMode dataMode) {
        return new SchemaBuilder()
            .add(FIRST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
            .add(LAST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
            .add(STUFF_COL, minorType, dataMode)
            .buildSchema();
    }

    /**
     * Generates the input rows, split into row sets of at most {@code maxBatchSize} rows.
     *
     * <p>Groups are visited in FIRST_NAMES x LAST_NAMES order and numbered 1, 2, 3, ...; group
     * {@code g} receives the values {@code 1*g, 2*g, ..., numRows*g} in its stuff column.
     *
     * @param numRows rows per (first name, last name) group; must be positive
     * @param maxBatchSize maximum rows per returned row set; must be positive
     * @return the input row sets in generation order
     * @throws IllegalArgumentException if either count is not positive
     */
    private List<RowSet> buildInputRowSets(TypeProtos.MinorType minorType, TypeProtos.DataMode dataMode,
                                           int numRows, int maxBatchSize) {
        Preconditions.checkArgument(numRows > 0, "numRows must be positive: %s", numRows);
        Preconditions.checkArgument(maxBatchSize > 0, "maxBatchSize must be positive: %s", maxBatchSize);

        List<RowSet> rowSets = new ArrayList<>();
        RowSetBuilder rowSetBuilder = null;
        int rowsInBatch = 0;
        int multiplier = 1;

        for (String firstName : FIRST_NAMES) {
            for (String lastName : LAST_NAMES) {
                for (int rowIndex = 1; rowIndex <= numRows; rowIndex++) {
                    if (rowsInBatch == 0) {
                        rowSetBuilder = new RowSetBuilder(this.operatorFixture.allocator(), buildInputSchema(minorType, dataMode));
                    }
                    rowSetBuilder.addRow(firstName, lastName, rowIndex * multiplier);
                    rowsInBatch++;
                    if (rowsInBatch == maxBatchSize) {
                        rowSets.add(rowSetBuilder.build());
                        rowsInBatch = 0;
                    }
                }
                multiplier++;
            }
        }
        // Flush the trailing, partially filled batch (non-null whenever rowsInBatch != 0).
        if (rowsInBatch != 0) {
            rowSets.add(rowSetBuilder.build());
        }
        return rowSets;
    }

    /**
     * Builds the expected aggregate output: one row per group, whose total is
     * {@code g * (1 + 2 + ... + numRows)} for group number {@code g}, matching
     * {@link #buildInputRowSets}.
     *
     * @param numRows rows per group used to generate the input
     * @return the expected output row set
     */
    private RowSet buildIntExpectedRowSet(int numRows) {
        RowSetBuilder rowSetBuilder = new RowSetBuilder(this.operatorFixture.allocator(), INT_OUTPUT_SCHEMA);
        // Computed in long: the int form n * (n + 1) overflows for moderately large row counts.
        long n = numRows;
        long triangularSum = n * (n + 1) / 2;
        int multiplier = 1;
        for (String firstName : FIRST_NAMES) {
            for (String lastName : LAST_NAMES) {
                rowSetBuilder.addRow(firstName, lastName, Long.valueOf(triangularSum * multiplier));
                multiplier++;
            }
        }
        return rowSetBuilder.build();
    }
}
