/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.operators.python.scalar;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.python.PythonOptions;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.functions.python.PythonEnv;
import org.apache.flink.table.functions.python.PythonFunction;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions;
import org.apache.flink.table.runtime.operators.python.scalar.AbstractPythonScalarFunctionOperator;
import org.apache.flink.table.types.AbstractDataType;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.VarCharType;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Base class for tests of Python scalar function operators.
 *
 * <p>Concrete subclasses supply the operator under test, row construction, output comparison,
 * the table environment and the output serializer. This class drives the operator through a
 * {@link OneInputStreamOperatorTestHarness} and checks the following:
 *
 * <ul>
 *   <li>when buffered results are flushed: bundle size reached, bundle time elapsed,
 *       pre-barrier snapshot and close;
 *   <li>that a watermark is only forwarded after the pending bundle has been flushed.
 * </ul>
 *
 * @param <IN> record type fed into the operator under test
 * @param <OUT> record type emitted by the operator under test
 * @param <UDFIN> not referenced in this base class; presumably the type handed to the Python
 *     function runner (TODO confirm against subclasses)
 */
public abstract class PythonScalarFunctionOperatorTestBase<IN, OUT, UDFIN> {

    /**
     * Checks that the boolean flag passed to {@link #newRow} survives the operator unchanged.
     * Rows built with {@code true} come out as {@code true}, rows built with {@code false} come
     * out as {@code false}, and all rows come out in input order.
     */
    @Test
    void testRetractionFieldKept() throws Exception {
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(new Configuration());
        long initialTime = 0L;
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1L));
        testHarness.processElement(new StreamRecord<>(newRow(false, "c3", "c4", 1L), initialTime + 2L));
        testHarness.processElement(new StreamRecord<>(newRow(false, "c5", "c6", 2L), initialTime + 3L));
        // Closing flushes whatever is still buffered, so compare only afterwards.
        testHarness.close();

        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        expectedOutput.add(new StreamRecord<>(newRow(false, "c3", "c4", 1L)));
        expectedOutput.add(new StreamRecord<>(newRow(false, "c5", "c6", 2L)));

        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
    }

    /**
     * Checks that {@code prepareSnapshotPreBarrier} flushes a bundle that has not yet reached
     * its configured maximum size.
     */
    @Test
    public void testFinishBundleTriggeredOnCheckpoint() throws Exception {
        Configuration conf = new Configuration();
        conf.set(PythonOptions.MAX_BUNDLE_SIZE, 10);
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
        long initialTime = 0L;
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1L));

        // The checkpoint barrier must force the single buffered element out.
        testHarness.prepareSnapshotPreBarrier(0L);

        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());

        testHarness.close();
    }

    /**
     * Checks that a bundle is flushed as soon as {@link PythonOptions#MAX_BUNDLE_SIZE} elements
     * have been received, and not before.
     */
    @Test
    public void testFinishBundleTriggeredByCount() throws Exception {
        Configuration conf = new Configuration();
        conf.set(PythonOptions.MAX_BUNDLE_SIZE, 2);
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
        long initialTime = 0L;
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1L));
        assertOutputEquals(
                "FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());

        // The second element fills the bundle (size 2) and triggers the flush.
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 1L), initialTime + 2L));
        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 1L)));
        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());

        testHarness.close();
    }

    /**
     * Checks that a partially filled bundle is flushed once processing time reaches
     * {@link PythonOptions#MAX_BUNDLE_TIME_MILLS}.
     */
    @Test
    public void testFinishBundleTriggeredByTime() throws Exception {
        Configuration conf = new Configuration();
        conf.set(PythonOptions.MAX_BUNDLE_SIZE, 10);
        conf.set(PythonOptions.MAX_BUNDLE_TIME_MILLS, 1000L);
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
        long initialTime = 0L;
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1L));
        assertOutputEquals(
                "FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());

        // Advancing processing time to the configured bundle time fires the flush timer.
        testHarness.setProcessingTime(1000L);
        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());

        testHarness.close();
    }

    /** Checks that closing the operator flushes a bundle that is still pending. */
    @Test
    public void testFinishBundleTriggeredByClose() throws Exception {
        Configuration conf = new Configuration();
        conf.set(PythonOptions.MAX_BUNDLE_SIZE, 10);
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
        long initialTime = 0L;
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1L));
        assertOutputEquals(
                "FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());

        testHarness.close();

        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
    }

    /**
     * Checks that a watermark arriving while a bundle is pending is held back. It is only emitted
     * after the bundle's results, once the bundle has been flushed by
     * {@code prepareSnapshotPreBarrier}.
     */
    @Test
    public void testWatermarkProcessedOnFinishBundle() throws Exception {
        Configuration conf = new Configuration();
        conf.set(PythonOptions.MAX_BUNDLE_SIZE, 10);
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
        long initialTime = 0L;
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1L));
        testHarness.processWatermark(initialTime + 2L);
        // Nothing (neither record nor watermark) may be emitted before the bundle finishes.
        assertOutputEquals("Watermark has been processed", expectedOutput, testHarness.getOutput());

        testHarness.prepareSnapshotPreBarrier(0L);

        // The record must precede the watermark.
        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        expectedOutput.add(new Watermark(initialTime + 2L));
        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());

        testHarness.close();
    }

    /**
     * Checks that, with default settings and parallelism 1, a pipeline consisting of a source
     * followed by a Python scalar function call compiles to a single job vertex. In other words,
     * the Python operator is chained with its upstream operator.
     */
    @Test
    public void testPythonScalarFunctionOperatorIsChainedByDefault() {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(1);
        StreamTableEnvironment tEnv = createTableEnvironment(env);
        tEnv.getConfig().set(TaskManagerOptions.TASK_OFF_HEAP_MEMORY, MemorySize.parse("80mb"));
        tEnv.createTemporarySystemFunction(
                "pyFunc", new JavaUserDefinedScalarFunctions.PythonScalarFunction("pyFunc"));

        DataStream<Tuple2<Integer, Integer>> ds = env.fromData(new Tuple2<>(1, 2));
        Table t =
                tEnv.fromDataStream(ds, Expressions.$("a"), Expressions.$("b"))
                        .select(Expressions.call("pyFunc", Expressions.$("a"), Expressions.$("b")));
        // Converting back to a DataStream forces the table program to be translated into the
        // stream graph inspected below.
        tEnv.toDataStream(t, DataTypes.INT());

        JobGraph jobGraph = env.getStreamGraph().getJobGraph();
        List<?> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
        Assertions.assertThat(vertices).hasSize(1);
    }

    /**
     * Builds an unopened harness around the operator returned by {@link #getTestOperator}.
     *
     * <p>The operator is configured as follows:
     *
     * <ul>
     *   <li>one {@link DummyPythonFunction} whose inputs are {@code [0]};
     *   <li>a {@code (f1 VARCHAR, f2 VARCHAR, f3 BIGINT)} row type, passed as both row-type
     *       arguments;
     *   <li>{@code [2]} and {@code [0, 1]} as the two int-array arguments.
     * </ul>
     *
     * <p>Half of the managed memory is reserved for the {@code PYTHON} use case.
     *
     * @param config operator configuration, e.g. bundle size and bundle time
     * @return a harness that has been set up but not yet opened
     * @throws Exception if the harness setup fails
     */
    private OneInputStreamOperatorTestHarness<IN, OUT> getTestHarness(Configuration config)
            throws Exception {
        RowType dataType =
                new RowType(
                        Arrays.asList(
                                new RowType.RowField("f1", new VarCharType()),
                                new RowType.RowField("f2", new VarCharType()),
                                new RowType.RowField("f3", new BigIntType())));
        AbstractPythonScalarFunctionOperator operator =
                getTestOperator(
                        config,
                        new PythonFunctionInfo[] {
                            new PythonFunctionInfo(DummyPythonFunction.INSTANCE, new Integer[] {0})
                        },
                        dataType,
                        dataType,
                        new int[] {2},
                        new int[] {0, 1});

        // The operator's element types are not expressed as IN/OUT in this base class, so the
        // harness is created through a raw type and re-parameterised here.
        @SuppressWarnings({"unchecked", "rawtypes"})
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness =
                new OneInputStreamOperatorTestHarness((OneInputStreamOperator) operator);
        testHarness
                .getStreamConfig()
                .setManagedMemoryFractionOperatorOfUseCase(ManagedMemoryUseCase.PYTHON, 0.5);
        testHarness.setup(getOutputTypeSerializer(dataType));
        return testHarness;
    }

    /**
     * Creates the operator under test.
     *
     * <p>NOTE(review): the meaning of each argument is defined by subclasses. The descriptions
     * below only reflect how {@link #getTestHarness} calls this method; confirm them against the
     * implementations.
     *
     * @param config operator configuration
     * @param scalarFunctions the Python functions to evaluate
     * @param inputType first row type argument (the harness passes the same type for both)
     * @param outputType second row type argument
     * @param udfInputOffsets first int-array argument ({@code [2]} in the harness); presumably
     *     the input fields fed to the UDFs (TODO confirm)
     * @param forwardedFields second int-array argument ({@code [0, 1]} in the harness);
     *     presumably the fields forwarded unchanged (TODO confirm)
     * @return the operator to wrap in the test harness
     */
    public abstract AbstractPythonScalarFunctionOperator getTestOperator(
            Configuration config,
            PythonFunctionInfo[] scalarFunctions,
            RowType inputType,
            RowType outputType,
            int[] udfInputOffsets,
            int[] forwardedFields);

    /**
     * Creates a row for input or expected output.
     *
     * @param accumulateMsg flag that the tests expect to be carried through unchanged; presumably
     *     insert ({@code true}) vs. retraction ({@code false}) (TODO confirm)
     * @param fields the row's field values
     * @return a row in the operator's input representation
     */
    public abstract IN newRow(boolean accumulateMsg, Object... fields);

    /**
     * Asserts that the actual harness output matches the expected output.
     *
     * @param message failure message
     * @param expected expected records and watermarks
     * @param actual output collected by the harness
     */
    public abstract void assertOutputEquals(
            String message, Collection<Object> expected, Collection<Object> actual);

    /**
     * Creates the table environment used by the chaining test.
     *
     * @param env the stream execution environment to bind to
     * @return a table environment on top of {@code env}
     */
    public abstract StreamTableEnvironment createTableEnvironment(StreamExecutionEnvironment env);

    /**
     * Returns the serializer for the operator's output records.
     *
     * @param outputType row type of the operator's output
     * @return serializer used to set up the harness
     */
    public abstract TypeSerializer<OUT> getOutputTypeSerializer(RowType outputType);

    /**
     * Placeholder {@link PythonFunction} with an empty serialized payload and a
     * {@code PROCESS}-mode {@link PythonEnv}.
     */
    public static class DummyPythonFunction implements PythonFunction {

        private static final long serialVersionUID = 1L;

        /** Shared instance; the class holds no state. */
        public static final PythonFunction INSTANCE = new DummyPythonFunction();

        /** Returns an empty payload; no real Python function is serialized. */
        @Override
        public byte[] getSerializedPythonFunction() {
            return new byte[0];
        }

        /** Returns a fresh {@link PythonEnv} with exec type {@code PROCESS}. */
        @Override
        public PythonEnv getPythonEnv() {
            return new PythonEnv(PythonEnv.ExecType.PROCESS);
        }
    }
}

