/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.state.api;

import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.time.Deadline;
import org.apache.flink.api.connector.sink2.Sink;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.core.execution.SavepointFormatType;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.hashmap.HashMapStateBackend;
import org.apache.flink.state.api.SavepointReader;
import org.apache.flink.state.api.utils.JobResultRetriever;
import org.apache.flink.state.api.utils.SavepointTestBase;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.sink.v2.DiscardingSink;
import org.apache.flink.streaming.api.functions.source.legacy.SourceFunction;
import org.apache.flink.streaming.api.graph.ExecutionPlan;
import org.apache.flink.test.util.AbstractTestBaseJUnit4;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

/**
 * Base class for integration tests that read operator state back from a savepoint.
 *
 * <p>The test runs a streaming job whose operator (registered under {@link #UID}) stores the
 * emitted elements in list state, union list state and broadcast state, takes a canonical
 * savepoint of that job, and then verifies each kind of state through a {@link SavepointReader}.
 * Subclasses provide the state descriptors and decide how each state is read from the reader.
 */
public abstract class SavepointReaderITTestBase extends AbstractTestBaseJUnit4 {
    static final String UID = "stateful-operator";
    static final String LIST_NAME = "list";
    static final String UNION_NAME = "union";
    static final String BROADCAST_NAME = "broadcast";

    private final ListStateDescriptor<Integer> list;
    private final ListStateDescriptor<Integer> union;
    private final MapStateDescriptor<Integer, String> broadcast;

    /**
     * @param list descriptor used for the operator's list state
     * @param union descriptor used for the operator's union list state
     * @param broadcast descriptor used for the operator's broadcast state
     */
    SavepointReaderITTestBase(
            ListStateDescriptor<Integer> list,
            ListStateDescriptor<Integer> union,
            MapStateDescriptor<Integer, String> broadcast) {
        this.list = list;
        this.union = union;
        this.broadcast = broadcast;
    }

    /**
     * Builds the stateful job with parallelism 4, takes a savepoint of it and checks that the
     * list, union and broadcast state read from that savepoint contain the source elements.
     */
    @Test
    public void testOperatorStateInputFormat() throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);

        DataStream<Integer> data = env.addSource(new SavepointSource()).rebalance();
        StatefulOperator statefulOperator =
                new StatefulOperator(this.list, this.union, this.broadcast);

        // The same stream feeds both the regular input and the broadcast side.
        data.connect(data.broadcast(this.broadcast))
                .process(statefulOperator)
                .uid(UID)
                .sinkTo(new DiscardingSink<>());

        JobGraph jobGraph = env.getStreamGraph().getJobGraph();
        String savepoint = this.takeSavepoint(jobGraph);

        this.verifyListState(savepoint, env);
        this.verifyUnionState(savepoint, env);
        this.verifyBroadcastState(savepoint, env);
    }

    /** Reads the list state of the operator {@link #UID} from the given savepoint. */
    abstract DataStream<Integer> readListState(SavepointReader savepoint) throws IOException;

    /** Reads the union list state of the operator {@link #UID} from the given savepoint. */
    abstract DataStream<Integer> readUnionState(SavepointReader savepoint) throws IOException;

    /** Reads the broadcast state of the operator {@link #UID} from the given savepoint. */
    abstract DataStream<Tuple2<Integer, String>> readBroadcastState(SavepointReader savepoint)
            throws IOException;

    /** Asserts that the sorted list state read from {@code path} equals the source elements. */
    private void verifyListState(String path, StreamExecutionEnvironment env) throws Exception {
        SavepointReader savepoint = SavepointReader.read(env, path, new HashMapStateBackend());
        List<Integer> listResult = JobResultRetriever.collect(this.readListState(savepoint));
        // Results are sorted because the assertion compares against the ordered source elements.
        listResult.sort(Comparator.naturalOrder());

        Assert.assertEquals(
                "Unexpected elements read from list state",
                SavepointSource.getElements(),
                listResult);
    }

    /** Asserts that the sorted union state read from {@code path} equals the source elements. */
    private void verifyUnionState(String path, StreamExecutionEnvironment env) throws Exception {
        SavepointReader savepoint = SavepointReader.read(env, path, new HashMapStateBackend());
        List<Integer> unionResult = JobResultRetriever.collect(this.readUnionState(savepoint));
        unionResult.sort(Comparator.naturalOrder());

        Assert.assertEquals(
                "Unexpected elements read from union state",
                SavepointSource.getElements(),
                unionResult);
    }

    /**
     * Asserts that the broadcast state read from {@code path} has the source elements as keys and
     * their string forms as values. Keys and values are sorted and compared independently.
     */
    private void verifyBroadcastState(String path, StreamExecutionEnvironment env)
            throws Exception {
        SavepointReader savepoint = SavepointReader.read(env, path, new HashMapStateBackend());
        List<Tuple2<Integer, String>> broadcastResult =
                JobResultRetriever.collect(this.readBroadcastState(savepoint));

        List<Integer> broadcastStateKeys =
                broadcastResult.stream()
                        .map(entry -> entry.f0)
                        .sorted(Comparator.naturalOrder())
                        .collect(Collectors.toList());
        List<String> broadcastStateValues =
                broadcastResult.stream()
                        .map(entry -> entry.f1)
                        .sorted(Comparator.naturalOrder())
                        .collect(Collectors.toList());
        List<String> expectedValues =
                SavepointSource.getElements().stream()
                        .map(Object::toString)
                        .sorted()
                        .collect(Collectors.toList());

        Assert.assertEquals(
                "Unexpected element in broadcast state keys",
                SavepointSource.getElements(),
                broadcastStateKeys);
        Assert.assertEquals(
                "Unexpected element in broadcast state values",
                expectedValues,
                broadcastStateValues);
    }

    /**
     * Submits {@code jobGraph}, waits (up to five minutes overall) until the source reports that
     * it has emitted all elements, triggers a canonical savepoint and finally cancels the job.
     *
     * @param jobGraph the job to run and snapshot
     * @return the path of the completed savepoint
     * @throws Exception if submission, the savepoint or the cancellation fails, or if the calling
     *     thread is interrupted while waiting
     */
    private String takeSavepoint(JobGraph jobGraph) throws Exception {
        SavepointSource.initializeForTest();

        ClusterClient<?> client = MINI_CLUSTER_RESOURCE.getClusterClient();
        JobID jobId = jobGraph.getJobID();
        Deadline deadline = Deadline.fromNow(Duration.ofMinutes(5L));
        String dirPath = this.getTempDirPath(new AbstractID().toHexString());

        Throwable failure = null;
        try {
            JobID submittedJobId = client.submitJob(jobGraph).get();
            SavepointTestBase.waitForAllRunningOrSomeTerminal(submittedJobId, MINI_CLUSTER_RESOURCE);

            // Poll the source's completion flag. An interrupt is propagated instead of being
            // re-flagged inside the loop, which would make every further sleep fail immediately
            // and turn the wait into a busy spin until the deadline.
            boolean finished = SavepointSource.isFinished();
            while (!finished && deadline.hasTimeLeft()) {
                Thread.sleep(2L);
                finished = SavepointSource.isFinished();
            }
            if (!finished) {
                Assert.fail("Failed to initialize state within deadline");
            }

            CompletableFuture<String> path =
                    client.triggerSavepoint(submittedJobId, dirPath, SavepointFormatType.CANONICAL);
            return path.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
        } catch (Throwable t) {
            // Remember the primary failure so a failing cancel below cannot replace it.
            failure = t;
            throw t;
        } finally {
            try {
                client.cancel(jobId).get();
            } catch (Exception cancelFailure) {
                if (failure == null) {
                    throw cancelFailure;
                }
                if (cancelFailure instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                }
                failure.addSuppressed(cancelFailure);
            }
        }
    }

    /**
     * Source that emits a fixed set of elements once and then idles until it is cancelled, so the
     * job stays alive while the savepoint is taken.
     */
    private static class SavepointSource implements SourceFunction<Integer> {
        /** Set once all elements have been emitted; polled by the test thread. */
        private static volatile boolean finished;

        private static final Integer[] elements = {1, 2, 3};

        private volatile boolean running = true;

        private SavepointSource() {
        }

        @Override
        public void run(SourceFunction.SourceContext<Integer> ctx) {
            // Emit everything and flip the flag while holding the checkpoint lock.
            synchronized (ctx.getCheckpointLock()) {
                for (Integer element : elements) {
                    ctx.collect(element);
                }
                finished = true;
            }

            while (this.running) {
                try {
                    Thread.sleep(100L);
                } catch (InterruptedException ignored) {
                    // Deliberately ignored: the loop only ends once cancel() clears 'running'.
                }
            }
        }

        @Override
        public void cancel() {
            this.running = false;
        }

        /** Resets the completion flag before a new job is submitted. */
        private static void initializeForTest() {
            finished = false;
        }

        /** Returns whether {@link #run} has emitted all elements. */
        private static boolean isFinished() {
            return finished;
        }

        /** Returns the emitted elements as a fixed-size list view of the backing array. */
        private static List<Integer> getElements() {
            return Arrays.asList(elements);
        }
    }

    /**
     * Operator that buffers every regular input element and writes the buffer to list and union
     * list state on snapshot; every broadcast element is stored in broadcast state as
     * {@code value -> value.toString()}.
     */
    private static class StatefulOperator extends BroadcastProcessFunction<Integer, Integer, Void>
            implements CheckpointedFunction {
        private final ListStateDescriptor<Integer> list;
        private final ListStateDescriptor<Integer> union;
        private final MapStateDescriptor<Integer, String> broadcast;

        /** Elements seen on the regular input; created in {@link #open}. */
        private List<Integer> elements;
        private ListState<Integer> listState;
        private ListState<Integer> unionState;

        private StatefulOperator(
                ListStateDescriptor<Integer> list,
                ListStateDescriptor<Integer> union,
                MapStateDescriptor<Integer, String> broadcast) {
            this.list = list;
            this.union = union;
            this.broadcast = broadcast;
        }

        @Override
        public void open(OpenContext openContext) {
            this.elements = new ArrayList<>();
        }

        @Override
        public void processElement(
                Integer value,
                BroadcastProcessFunction<Integer, Integer, Void>.ReadOnlyContext ctx,
                Collector<Void> out) {
            this.elements.add(value);
        }

        @Override
        public void processBroadcastElement(
                Integer value,
                BroadcastProcessFunction<Integer, Integer, Void>.Context ctx,
                Collector<Void> out)
                throws Exception {
            ctx.getBroadcastState(this.broadcast).put(value, value.toString());
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            // Both states are replaced with the full buffer on every snapshot.
            this.listState.update(this.elements);
            this.unionState.update(this.elements);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            this.listState = context.getOperatorStateStore().getListState(this.list);
            this.unionState = context.getOperatorStateStore().getUnionListState(this.union);
        }
    }
}

