/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.adaptive;

import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
import org.apache.flink.table.api.config.ExecutionConfigOptions;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.planner.adaptive.AdaptiveJoinOperatorGenerator;
import org.apache.flink.table.planner.plan.utils.OperatorType;
import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
import org.apache.flink.table.runtime.generated.JoinCondition;
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory;
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
import org.apache.flink.table.runtime.operators.join.HashJoinOperator;
import org.apache.flink.table.runtime.operators.join.Int2HashJoinOperatorTestBase;
import org.apache.flink.table.runtime.operators.join.SortMergeJoinOperator;
import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
import org.apache.flink.table.runtime.util.JoinUtil;
import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.MutableObjectIterator;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link AdaptiveJoinOperatorGenerator}.
 *
 * <p>For each case an {@link AdaptiveJoin} is generated for an original physical join type. It is
 * then optionally marked as a broadcast join. Each case checks two things:
 *
 * <ul>
 *   <li>the produced operator factory corresponds to the expected {@link OperatorType};
 *   <li>running the operator over generated inputs yields the expected output, as checked by
 *       {@code joinAndAssert} from {@link Int2HashJoinOperatorTestBase}.
 * </ul>
 */
class AdaptiveJoinOperatorGeneratorTest extends Int2HashJoinOperatorTestBase {

    @Test
    void testShuffleHashJoinTransformationCorrectness() throws Exception {
        // Not marked as broadcast: every join type stays a shuffle hash join.
        testInnerJoin(true, OperatorType.ShuffleHashJoin, false, OperatorType.ShuffleHashJoin);
        testInnerJoin(false, OperatorType.ShuffleHashJoin, false, OperatorType.ShuffleHashJoin);
        testLeftOutJoin(true, OperatorType.ShuffleHashJoin, false, OperatorType.ShuffleHashJoin);
        testLeftOutJoin(false, OperatorType.ShuffleHashJoin, false, OperatorType.ShuffleHashJoin);
        testRightOutJoin(true, OperatorType.ShuffleHashJoin, false, OperatorType.ShuffleHashJoin);
        testRightOutJoin(false, OperatorType.ShuffleHashJoin, false, OperatorType.ShuffleHashJoin);
        testSemiJoin(OperatorType.ShuffleHashJoin, false, OperatorType.ShuffleHashJoin);
        testAntiJoin(OperatorType.ShuffleHashJoin, false, OperatorType.ShuffleHashJoin);

        // Marked as broadcast: the join becomes a broadcast hash join. Outer joins are only
        // exercised with the non-preserved side as build side (left outer builds right,
        // right outer builds left).
        testInnerJoin(true, OperatorType.ShuffleHashJoin, true, OperatorType.BroadcastHashJoin);
        testInnerJoin(false, OperatorType.ShuffleHashJoin, true, OperatorType.BroadcastHashJoin);
        testLeftOutJoin(false, OperatorType.ShuffleHashJoin, true, OperatorType.BroadcastHashJoin);
        testRightOutJoin(true, OperatorType.ShuffleHashJoin, true, OperatorType.BroadcastHashJoin);
        testSemiJoin(OperatorType.ShuffleHashJoin, true, OperatorType.BroadcastHashJoin);
        testAntiJoin(OperatorType.ShuffleHashJoin, true, OperatorType.BroadcastHashJoin);
    }

    @Test
    void testSortMergeJoinTransformationCorrectness() throws Exception {
        // Not marked as broadcast: every join type stays a sort-merge join.
        testInnerJoin(true, OperatorType.SortMergeJoin, false, OperatorType.SortMergeJoin);
        testInnerJoin(false, OperatorType.SortMergeJoin, false, OperatorType.SortMergeJoin);
        testLeftOutJoin(true, OperatorType.SortMergeJoin, false, OperatorType.SortMergeJoin);
        testRightOutJoin(true, OperatorType.SortMergeJoin, false, OperatorType.SortMergeJoin);
        // Previously testAntiJoin was called twice here, so the non-broadcast sort-merge SEMI
        // case (and the dedicated input swap for it in testSemiOrAntiJoin) was never run.
        testSemiJoin(OperatorType.SortMergeJoin, false, OperatorType.SortMergeJoin);
        testAntiJoin(OperatorType.SortMergeJoin, false, OperatorType.SortMergeJoin);

        // Marked as broadcast: the join becomes a broadcast hash join.
        testInnerJoin(true, OperatorType.SortMergeJoin, true, OperatorType.BroadcastHashJoin);
        testInnerJoin(false, OperatorType.SortMergeJoin, true, OperatorType.BroadcastHashJoin);
        testLeftOutJoin(false, OperatorType.SortMergeJoin, true, OperatorType.BroadcastHashJoin);
        testRightOutJoin(true, OperatorType.SortMergeJoin, true, OperatorType.BroadcastHashJoin);
        testSemiJoin(OperatorType.SortMergeJoin, true, OperatorType.BroadcastHashJoin);
        testAntiJoin(OperatorType.SortMergeJoin, true, OperatorType.BroadcastHashJoin);
    }

    /**
     * Runs an INNER join where both inputs share the same 100 keys.
     *
     * @param isBuildLeft whether the left input is the build side
     * @param originalJoinType physical join type the adaptive join is generated from
     * @param isBroadcast whether the adaptive join is marked as a broadcast join
     * @param expectedOperatorType operator type the generated factory must correspond to
     */
    private void testInnerJoin(
            boolean isBuildLeft,
            OperatorType originalJoinType,
            boolean isBroadcast,
            OperatorType expectedOperatorType)
            throws Exception {
        int numKeys = 100;
        int buildValsPerKey = 3;
        int probeValsPerKey = 10;
        MutableObjectIterator<BinaryRowData> buildInput =
                new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false);
        MutableObjectIterator<BinaryRowData> probeInput =
                new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true);
        buildJoin(
                buildInput,
                probeInput,
                originalJoinType,
                expectedOperatorType,
                false,
                false,
                isBuildLeft,
                isBroadcast,
                numKeys * buildValsPerKey * probeValsPerKey,
                numKeys,
                165);
    }

    /**
     * Runs a LEFT OUTER join between a 9-key and a 10-key input.
     *
     * <p>The 9-key generator is assigned to the build side when {@code isBuildLeft} is true, and
     * to the probe side otherwise.
     */
    private void testLeftOutJoin(
            boolean isBuildLeft,
            OperatorType originalJoinType,
            boolean isBroadcast,
            OperatorType expectedOperatorType)
            throws Exception {
        int numKeys1 = 9;
        int numKeys2 = 10;
        int buildValsPerKey = 3;
        int probeValsPerKey = 10;
        MutableObjectIterator<BinaryRowData> buildInput =
                new UniformBinaryRowGenerator(
                        isBuildLeft ? numKeys1 : numKeys2, buildValsPerKey, true);
        MutableObjectIterator<BinaryRowData> probeInput =
                new UniformBinaryRowGenerator(
                        isBuildLeft ? numKeys2 : numKeys1, probeValsPerKey, true);
        buildJoin(
                buildInput,
                probeInput,
                originalJoinType,
                expectedOperatorType,
                true,
                false,
                isBuildLeft,
                isBroadcast,
                numKeys1 * buildValsPerKey * probeValsPerKey,
                numKeys1,
                165);
    }

    /**
     * Runs a RIGHT OUTER join with a 9-key build input and a 10-key probe input.
     *
     * <p>The expected row count depends on which side is the build side (280 vs. 270). The value
     * check is passed as {@code -1}; NOTE(review): presumably this disables that check in
     * {@code joinAndAssert}, so confirm against the base class.
     */
    private void testRightOutJoin(
            boolean isBuildLeft,
            OperatorType originalJoinType,
            boolean isBroadcast,
            OperatorType expectedOperatorType)
            throws Exception {
        int numKeys1 = 9;
        int numKeys2 = 10;
        int buildValsPerKey = 3;
        int probeValsPerKey = 10;
        MutableObjectIterator<BinaryRowData> buildInput =
                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
        MutableObjectIterator<BinaryRowData> probeInput =
                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
        buildJoin(
                buildInput,
                probeInput,
                originalJoinType,
                expectedOperatorType,
                false,
                true,
                isBuildLeft,
                isBroadcast,
                isBuildLeft ? 280 : 270,
                numKeys2,
                -1);
    }

    /** Runs a SEMI join and expects 90 output rows over 9 distinct keys (value check 45). */
    private void testSemiJoin(
            OperatorType originalJoinType, boolean isBroadcast, OperatorType expectedOperatorType)
            throws Exception {
        testSemiOrAntiJoin(
                FlinkJoinType.SEMI, originalJoinType, isBroadcast, expectedOperatorType, 90, 9, 45);
    }

    /** Runs an ANTI join and expects 10 output rows over 1 distinct key (value check 45). */
    private void testAntiJoin(
            OperatorType originalJoinType, boolean isBroadcast, OperatorType expectedOperatorType)
            throws Exception {
        testSemiOrAntiJoin(
                FlinkJoinType.ANTI, originalJoinType, isBroadcast, expectedOperatorType, 10, 1, 45);
    }

    /**
     * Shared driver for SEMI and ANTI joins.
     *
     * <p>It builds a 9-key/3-values input and a 10-key/10-values input, generates the operator
     * with the right side as build side ({@code buildLeft = false}), checks the operator type,
     * and then checks the join output in semi-join mode.
     *
     * @param joinType {@link FlinkJoinType#SEMI} or {@link FlinkJoinType#ANTI}
     * @param expectOutSize expected number of output rows
     * @param expectOutKeySize expected number of distinct output keys
     * @param expectOutVal expected value check, as defined by {@code joinAndAssert}
     */
    private void testSemiOrAntiJoin(
            FlinkJoinType joinType,
            OperatorType originalJoinType,
            boolean isBroadcast,
            OperatorType expectedOperatorType,
            int expectOutSize,
            int expectOutKeySize,
            int expectOutVal)
            throws Exception {
        int numKeys1 = 9;
        int numKeys2 = 10;
        int buildValsPerKey = 3;
        int probeValsPerKey = 10;
        if (originalJoinType == OperatorType.SortMergeJoin && !isBroadcast) {
            // NOTE(review): a non-broadcast sort-merge join appears to consume the two inputs
            // in left/right rather than build/probe order. The data sets are swapped so the same
            // expected results apply; confirm against joinAndAssert.
            numKeys1 = 10;
            numKeys2 = 9;
            buildValsPerKey = 10;
            probeValsPerKey = 3;
        }
        MutableObjectIterator<BinaryRowData> buildInput =
                new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true);
        MutableObjectIterator<BinaryRowData> probeInput =
                new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true);
        Object operator = newOperator(joinType, false, isBroadcast, originalJoinType);
        assertOperatorType(operator, expectedOperatorType);
        joinAndAssert(
                operator,
                buildInput,
                probeInput,
                expectOutSize,
                expectOutKeySize,
                expectOutVal,
                true);
    }

    /**
     * Generates the operator for the join described by {@code leftOut}/{@code rightOut}, checks
     * its type, and runs it over the given inputs (non-semi mode).
     *
     * @param buildInput rows fed as the build input
     * @param probeInput rows fed as the probe input
     * @param originalJoinType physical join type the adaptive join is generated from
     * @param expectedOperatorType operator type the generated factory must correspond to
     * @param leftOut whether the join preserves the left side (left/full outer)
     * @param rightOut whether the join preserves the right side (right/full outer)
     * @param buildLeft whether the left input is the build side
     * @param isBroadcast whether the adaptive join is marked as a broadcast join
     * @param expectOutSize expected number of output rows
     * @param expectOutKeySize expected number of distinct output keys
     * @param expectOutVal expected value check, as defined by {@code joinAndAssert}
     */
    public void buildJoin(
            MutableObjectIterator<BinaryRowData> buildInput,
            MutableObjectIterator<BinaryRowData> probeInput,
            OperatorType originalJoinType,
            OperatorType expectedOperatorType,
            boolean leftOut,
            boolean rightOut,
            boolean buildLeft,
            boolean isBroadcast,
            int expectOutSize,
            int expectOutKeySize,
            int expectOutVal)
            throws Exception {
        FlinkJoinType flinkJoinType = JoinUtil.getJoinType(leftOut, rightOut);
        Object operator = newOperator(flinkJoinType, buildLeft, isBroadcast, originalJoinType);
        assertOperatorType(operator, expectedOperatorType);
        joinAndAssert(
                operator,
                buildInput,
                probeInput,
                expectOutSize,
                expectOutKeySize,
                expectOutVal,
                false);
    }

    /**
     * Generates an adaptive join, applies the broadcast decision, and returns the operator factory
     * it produces, using this test's class loader and an empty {@link Configuration}.
     */
    public Object newOperator(
            FlinkJoinType flinkJoinType,
            boolean buildLeft,
            boolean isBroadcast,
            OperatorType operatorType) {
        AdaptiveJoin adaptiveJoin = genAdaptiveJoin(flinkJoinType, operatorType);
        adaptiveJoin.markAsBroadcastJoin(isBroadcast, buildLeft);
        ReadableConfig config = new Configuration();
        return adaptiveJoin.genOperatorFactory(getClass().getClassLoader(), config);
    }

    /**
     * Asserts that {@code operator} corresponds to {@code expectedOperatorType}.
     *
     * <p>Hash joins (shuffle or broadcast) may be either a code-generated factory whose class name
     * contains {@code LongHashJoinOperator}, or a {@link SimpleOperatorFactory} wrapping a {@link
     * HashJoinOperator}. Sort-merge joins must be a {@link SimpleOperatorFactory} wrapping a {@link
     * SortMergeJoinOperator}.
     *
     * @throws IllegalArgumentException if {@code expectedOperatorType} is not a supported join type
     */
    public void assertOperatorType(Object operator, OperatorType expectedOperatorType) {
        switch (expectedOperatorType) {
            case BroadcastHashJoin:
            case ShuffleHashJoin:
                if (operator instanceof CodeGenOperatorFactory) {
                    String generatedClassName =
                            ((CodeGenOperatorFactory<?>) operator)
                                    .getGeneratedClass()
                                    .getClassName();
                    Assertions.assertThat(generatedClassName).contains("LongHashJoinOperator");
                } else {
                    assertSimpleFactoryWrapping(operator, HashJoinOperator.class);
                }
                break;
            case SortMergeJoin:
                assertSimpleFactoryWrapping(operator, SortMergeJoinOperator.class);
                break;
            default:
                throw new IllegalArgumentException(
                        String.format("Unexpected operator type %s.", expectedOperatorType));
        }
    }

    /**
     * Asserts that {@code factory} is a {@link SimpleOperatorFactory} whose wrapped operator is an
     * instance of {@code operatorClass}.
     */
    private static void assertSimpleFactoryWrapping(Object factory, Class<?> operatorClass) {
        Assertions.assertThat(factory).isInstanceOf(SimpleOperatorFactory.class);
        Object wrapped = ((SimpleOperatorFactory<?>) factory).getOperator();
        Assertions.assertThat(wrapped).isInstanceOf(operatorClass);
    }

    /**
     * Creates an {@link AdaptiveJoinOperatorGenerator} for the given join and physical operator
     * type.
     *
     * <p>Both sides are {@code ROW<INT, INT>} joined on field 0. The join condition is the test
     * base's {@code MyJoinCondition}, instantiated directly rather than compiled from code.
     */
    public AdaptiveJoin genAdaptiveJoin(FlinkJoinType flinkJoinType, OperatorType operatorType) {
        GeneratedJoinCondition condFuncCode =
                new GeneratedJoinCondition(
                        Int2HashJoinOperatorTestBase.MyJoinCondition.class.getCanonicalName(),
                        "",
                        new Object[0]) {
                    @Override
                    public JoinCondition newInstance(ClassLoader classLoader) {
                        return new Int2HashJoinOperatorTestBase.MyJoinCondition(new Object[0]);
                    }
                };
        MemorySize hashJoinMemory =
                ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_JOIN_MEMORY.defaultValue();
        // NOTE(review): the positional literals below (20, 10000, 20L, 10000L, false, true) are
        // passed straight through to AdaptiveJoinOperatorGenerator's constructor; see its
        // parameter list for their meaning.
        return new AdaptiveJoinOperatorGenerator(
                new int[] {0},
                new int[] {0},
                flinkJoinType,
                new boolean[] {true},
                RowType.of(new IntType(), new IntType()),
                RowType.of(new IntType(), new IntType()),
                condFuncCode,
                20,
                10000,
                20L,
                10000L,
                false,
                hashJoinMemory.getBytes(),
                true,
                operatorType);
    }
}

