/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.state.api;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.apache.flink.api.common.RuntimeExecutionMode;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.connector.sink2.Sink;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.StateBackendOptions;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.hashmap.HashMapStateBackend;
import org.apache.flink.state.api.OperatorIdentifier;
import org.apache.flink.state.api.OperatorTransformation;
import org.apache.flink.state.api.SavepointWriter;
import org.apache.flink.state.api.StateBootstrapTransformation;
import org.apache.flink.state.api.functions.BroadcastStateBootstrapFunction;
import org.apache.flink.state.api.functions.KeyedStateBootstrapFunction;
import org.apache.flink.state.api.functions.StateBootstrapFunction;
import org.apache.flink.state.rocksdb.EmbeddedRocksDBStateBackend;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.sink.v2.DiscardingSink;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.test.util.AbstractTestBaseJUnit4;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.Collector;
import org.assertj.core.api.Assertions;
import org.junit.Assert;
import org.junit.Test;

/**
 * Integration test for {@link SavepointWriter}.
 *
 * <p>Each test case performs four steps:
 *
 * <ol>
 *   <li>Bootstraps keyed state (account totals) and broadcast state (currency rates) into a new
 *       savepoint.
 *   <li>Restores a streaming job from that savepoint and validates it.
 *   <li>Derives a second savepoint that removes the currency operator and adds union list state
 *       for a new operator.
 *   <li>Restores from the second savepoint and validates again.
 * </ol>
 *
 * <p>The cycle runs once with the default writer backend, once with {@link HashMapStateBackend}
 * and once with {@link EmbeddedRocksDBStateBackend}.
 */
public class SavepointWriterITCase extends AbstractTestBaseJUnit4 {

    private static final String ACCOUNT_UID = "accounts";

    private static final String CURRENCY_UID = "currency";

    private static final String MODIFY_UID = "numbers";

    /** Max parallelism used for every savepoint written by this test. */
    private static final int MAX_PARALLELISM = 128;

    /** Keyed value state written by {@link AccountBootstrapper}, read by {@link UpdateAndGetAccount}. */
    private static final String TOTAL_STATE_NAME = "total";

    /** Broadcast state written by {@link CurrencyBootstrapFunction}, read by {@link CurrencyValidationFunction}. */
    private static final MapStateDescriptor<String, Double> CURRENCY_RATE_DESCRIPTOR =
            new MapStateDescriptor<>("currency-rate", Types.STRING, Types.DOUBLE);

    private static final Collection<Account> ACCOUNTS =
            Arrays.asList(new Account(1, 100.0), new Account(2, 100.0), new Account(3, 100.0));

    private static final Collection<CurrencyRate> CURRENCY_RATES =
            Arrays.asList(new CurrencyRate("USD", 1.0), new CurrencyRate("EUR", 1.3));

    /**
     * Elements bootstrapped by {@link ModifyProcessFunction} and expected back, exactly once each,
     * by {@link StatefulOperator}.
     */
    private static final List<Integer> NUMBERS = Arrays.asList(1, 2, 3);

    @Test
    public void testDefaultStateBackend() throws Exception {
        testStateBootstrapAndModification(new Configuration(), null);
    }

    @Test
    public void testHashMapStateBackend() throws Exception {
        testStateBootstrapAndModification(
                new Configuration().set(StateBackendOptions.STATE_BACKEND, "hashmap"),
                new HashMapStateBackend());
    }

    @Test
    public void testEmbeddedRocksDBStateBackend() throws Exception {
        testStateBootstrapAndModification(
                new Configuration().set(StateBackendOptions.STATE_BACKEND, "rocksdb"),
                new EmbeddedRocksDBStateBackend());
    }

    /**
     * Runs the full bootstrap, validate, modify, validate cycle.
     *
     * @param config configuration for the validating (restoring) jobs
     * @param backend backend handed to the {@link SavepointWriter}, or {@code null} to use the
     *     writer's default
     * @throws Exception if any of the jobs fails
     */
    public void testStateBootstrapAndModification(Configuration config, StateBackend backend)
            throws Exception {
        String savepointPath = getTempDirPath(new AbstractID().toHexString());
        bootstrapState(backend, savepointPath);
        validateBootstrap(config, savepointPath);

        String modifyPath = getTempDirPath(new AbstractID().toHexString());
        modifySavepoint(backend, savepointPath, modifyPath);
        validateModification(config, modifyPath);
    }

    /** Writes a new savepoint with keyed account state and broadcast currency-rate state. */
    private void bootstrapState(StateBackend backend, String savepointPath) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setRuntimeMode(RuntimeExecutionMode.AUTOMATIC);

        StateBootstrapTransformation<Account> transformation =
                OperatorTransformation.bootstrapWith(env.fromData(ACCOUNTS))
                        .keyBy(acc -> acc.id)
                        .transform(new AccountBootstrapper());

        StateBootstrapTransformation<CurrencyRate> broadcastTransformation =
                OperatorTransformation.bootstrapWith(env.fromData(CURRENCY_RATES))
                        .transform(new CurrencyBootstrapFunction());

        SavepointWriter writer =
                backend == null
                        ? SavepointWriter.newSavepoint(env, MAX_PARALLELISM)
                        : SavepointWriter.newSavepoint(env, backend, MAX_PARALLELISM);

        // The currency operator is registered by uid hash, whereas validateBootstrap addresses it
        // through uid(CURRENCY_UID).
        writer.withOperator(OperatorIdentifier.forUid(ACCOUNT_UID), transformation)
                .withOperator(getUidHashFromUid(CURRENCY_UID), broadcastTransformation)
                .write(savepointPath);

        env.execute("Bootstrap");
    }

    /** Restores a job from the bootstrapped savepoint and checks the accounts and currency rates. */
    private void validateBootstrap(Configuration configuration, String savepointPath)
            throws Exception {
        StreamExecutionEnvironment env =
                StreamExecutionEnvironment.getExecutionEnvironment(configuration);

        SingleOutputStreamOperator<Account> stream =
                env.fromData(ACCOUNTS)
                        .keyBy(acc -> acc.id)
                        .flatMap(new UpdateAndGetAccount())
                        .uid(ACCOUNT_UID);

        // try-with-resources so the collect iterator is released even if the job or the
        // assertion fails.
        try (CloseableIterator<Account> results = stream.collectAsync()) {
            env.fromData(CURRENCY_RATES)
                    .connect(env.fromData(CURRENCY_RATES).broadcast(CURRENCY_RATE_DESCRIPTOR))
                    .process(new CurrencyValidationFunction())
                    .uid(CURRENCY_UID)
                    .sinkTo(new DiscardingSink<>());

            executeFromSavepoint(env, savepointPath);
            Assertions.assertThat(results).toIterable().hasSize(ACCOUNTS.size());
        }
    }

    /**
     * Derives a new savepoint from {@code savepointPath}. The currency operator is dropped, and
     * union list state for the {@link #MODIFY_UID} operator is added. The result is written to
     * {@code modifyPath}.
     */
    private void modifySavepoint(StateBackend backend, String savepointPath, String modifyPath)
            throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setRuntimeMode(RuntimeExecutionMode.AUTOMATIC);

        StateBootstrapTransformation<Integer> transformation =
                OperatorTransformation.bootstrapWith(env.fromData(NUMBERS))
                        .transform(new ModifyProcessFunction());

        SavepointWriter writer =
                backend == null
                        ? SavepointWriter.fromExistingSavepoint(env, savepointPath)
                        : SavepointWriter.fromExistingSavepoint(env, savepointPath, backend);

        writer.removeOperator(OperatorIdentifier.forUid(CURRENCY_UID))
                .withOperator(getUidHashFromUid(MODIFY_UID), transformation)
                .write(modifyPath);

        env.execute("Modifying");
    }

    /** Restores a job from the modified savepoint and checks the accounts and the new operator state. */
    private void validateModification(Configuration configuration, String savepointPath)
            throws Exception {
        StreamExecutionEnvironment sEnv =
                StreamExecutionEnvironment.getExecutionEnvironment(configuration);

        SingleOutputStreamOperator<Account> stream =
                sEnv.fromData(ACCOUNTS)
                        .keyBy(acc -> acc.id)
                        .flatMap(new UpdateAndGetAccount())
                        .uid(ACCOUNT_UID);

        try (CloseableIterator<Account> results = stream.collectAsync()) {
            stream.map(acc -> acc.id)
                    .map(new StatefulOperator())
                    .uid(MODIFY_UID)
                    .sinkTo(new DiscardingSink<>());

            executeFromSavepoint(sEnv, savepointPath);
            Assertions.assertThat(results).toIterable().hasSize(ACCOUNTS.size());
        }
    }

    /**
     * Builds the stream graph of {@code env}, configures it to restore from {@code savepointPath}
     * and executes it. The restore uses {@code allowNonRestoredState = false}.
     */
    private static void executeFromSavepoint(StreamExecutionEnvironment env, String savepointPath)
            throws Exception {
        StreamGraph streamGraph = env.getStreamGraph();
        streamGraph.setSavepointRestoreSettings(
                SavepointRestoreSettings.forPath(savepointPath, false));
        env.execute(streamGraph);
    }

    /** Returns an identifier for the operator with the given uid, expressed as its uid hash. */
    private static OperatorIdentifier getUidHashFromUid(String uid) {
        return OperatorIdentifier.forUidHash(
                OperatorIdentifier.forUid(uid).getOperatorId().toHexString());
    }

    /** Stores each account's amount in the keyed {@link #TOTAL_STATE_NAME} value state. */
    public static class AccountBootstrapper extends KeyedStateBootstrapFunction<Integer, Account> {
        ValueState<Double> state;

        @Override
        public void open(OpenContext openContext) {
            ValueStateDescriptor<Double> totalDescriptor =
                    new ValueStateDescriptor<>(TOTAL_STATE_NAME, Types.DOUBLE);
            state = getRuntimeContext().getState(totalDescriptor);
        }

        @Override
        public void processElement(Account value, Context ctx) throws Exception {
            state.update(value.amount);
        }
    }

    /** Puts every currency rate into the {@link #CURRENCY_RATE_DESCRIPTOR} broadcast state. */
    public static class CurrencyBootstrapFunction
            extends BroadcastStateBootstrapFunction<CurrencyRate> {

        @Override
        public void processElement(CurrencyRate value, Context ctx) throws Exception {
            ctx.getBroadcastState(CURRENCY_RATE_DESCRIPTOR).put(value.currency, value.rate);
        }
    }

    /**
     * Adds the stored total (if any) to the incoming account's amount and stores the new total.
     * Then emits the updated account.
     */
    public static class UpdateAndGetAccount extends RichFlatMapFunction<Account, Account> {
        ValueState<Double> state;

        @Override
        public void open(OpenContext openContext) throws Exception {
            super.open(openContext);
            ValueStateDescriptor<Double> totalDescriptor =
                    new ValueStateDescriptor<>(TOTAL_STATE_NAME, Types.DOUBLE);
            state = getRuntimeContext().getState(totalDescriptor);
        }

        @Override
        public void flatMap(Account value, Collector<Account> out) throws Exception {
            Double current = state.value();
            if (current != null) {
                value.amount += current;
            }
            state.update(value.amount);
            out.collect(value);
        }
    }

    /** Asserts that each element's rate matches the one found in the restored broadcast state. */
    public static class CurrencyValidationFunction
            extends BroadcastProcessFunction<CurrencyRate, CurrencyRate, Void> {

        @Override
        public void processElement(
                CurrencyRate value, ReadOnlyContext ctx, Collector<Void> out) throws Exception {
            Double restored = ctx.getBroadcastState(CURRENCY_RATE_DESCRIPTOR).get(value.currency);
            // Fail with an assertion instead of an NPE from unboxing if the rate is missing.
            Assert.assertNotNull("Missing currency rate for " + value.currency, restored);
            Assert.assertEquals(
                    "Incorrect currency rate",
                    value.rate.doubleValue(),
                    restored.doubleValue(),
                    1.0E-4);
        }

        @Override
        public void processBroadcastElement(
                CurrencyRate value, Context ctx, Collector<Void> out) {
            // Intentionally empty: this function never writes broadcast state, so
            // processElement only sees what was restored from the savepoint.
        }
    }

    /**
     * Collects every bootstrapped integer. On snapshot, writes them into union list state named
     * {@link #MODIFY_UID}.
     */
    public static class ModifyProcessFunction extends StateBootstrapFunction<Integer> {
        List<Integer> numbers;
        ListState<Integer> state;

        @Override
        public void open(OpenContext openContext) {
            numbers = new ArrayList<>();
        }

        @Override
        public void processElement(Integer value, Context ctx) {
            numbers.add(value);
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            state.update(numbers);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            state =
                    context.getOperatorStateStore()
                            .getUnionListState(new ListStateDescriptor<>(MODIFY_UID, Types.INT));
        }
    }

    /**
     * Checks on restore that the union list state {@link #MODIFY_UID} contains each element of
     * {@link #NUMBERS} exactly once.
     */
    public static class StatefulOperator extends RichMapFunction<Integer, Integer>
            implements CheckpointedFunction {
        List<Integer> numbers;
        ListState<Integer> state;

        @Override
        public void open(OpenContext openContext) {
            numbers = new ArrayList<>();
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            state.update(numbers);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            state =
                    context.getOperatorStateStore()
                            .getUnionListState(new ListStateDescriptor<>(MODIFY_UID, Types.INT));

            if (context.isRestored()) {
                Set<Integer> expected = new HashSet<>(NUMBERS);
                for (Integer number : state.get()) {
                    // remove() is false both for values never expected and for ones already seen.
                    Assert.assertTrue(
                            "Unexpected or duplicate state element: " + number,
                            expected.remove(number));
                }
                Assert.assertTrue(
                        "Failed to bootstrap all state elements: "
                                + Arrays.toString(expected.toArray()),
                        expected.isEmpty());
            }
        }

        @Override
        public Integer map(Integer value) {
            // Output goes to a DiscardingSink; the operator exists only for its restored state.
            return null;
        }
    }

    /** Test record keyed by {@code id}; {@code timestamp} is not part of equality. */
    public static class Account {
        public int id;
        public double amount;
        public long timestamp;

        Account(int id, double amount) {
            this.id = id;
            this.amount = amount;
            this.timestamp = 1000L;
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Account)) {
                return false;
            }
            Account other = (Account) obj;
            // Double.compare keeps equals consistent with the Double.hashCode used by Objects.hash
            // (e.g. for 0.0 vs -0.0 and NaN).
            return other.id == id && Double.compare(other.amount, amount) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hash(id, amount);
        }

        @Override
        public String toString() {
            return "Account{id=" + id + ", amount=" + amount + ", timestamp=" + timestamp + "}";
        }
    }

    /** Test record pairing a currency code with its rate. */
    public static class CurrencyRate {
        public String currency;
        public Double rate;

        CurrencyRate(String currency, double rate) {
            this.currency = currency;
            this.rate = rate;
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof CurrencyRate)) {
                return false;
            }
            CurrencyRate other = (CurrencyRate) obj;
            return Objects.equals(other.currency, currency) && Objects.equals(other.rate, rate);
        }

        @Override
        public int hashCode() {
            return Objects.hash(currency, rate);
        }

        @Override
        public String toString() {
            return "CurrencyRate{currency=" + currency + ", rate=" + rate + "}";
        }
    }
}

