/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.flink.core.testutils.FlinkAssertions;
import org.apache.flink.core.testutils.FlinkCompletableFutureAssert;
import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
import org.apache.flink.runtime.checkpoint.CheckpointPlan;
import org.apache.flink.runtime.checkpoint.CheckpointPlanCalculatorContext;
import org.apache.flink.runtime.checkpoint.DefaultCheckpointPlanCalculator;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionGraphCheckpointPlanCalculatorContext;
import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.runtime.util.JobVertexConnectionUtils;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.assertj.core.api.AbstractCollectionAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Tests for {@link DefaultCheckpointPlanCalculator}.
 *
 * <p>Most tests declare a job topology (vertices with a parallelism and a set of already-finished
 * subtasks, plus the edges between them), build a started execution graph from it, ask the
 * calculator for a {@link CheckpointPlan} and compare every part of the plan with the expected
 * tasks.
 */
class DefaultCheckpointPlanCalculatorTest {

    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_EXTENSION =
            TestingUtils.defaultExecutorExtension();

    @Test
    void testComputeAllRunningGraph() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(3, Collections.emptySet()),
                        new VertexDeclaration(4, Collections.emptySet()),
                        new VertexDeclaration(5, Collections.emptySet()),
                        new VertexDeclaration(6, Collections.emptySet())),
                Arrays.asList(
                        new EdgeDeclaration(0, 2, DistributionPattern.ALL_TO_ALL),
                        new EdgeDeclaration(1, 2, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(2, 3, DistributionPattern.ALL_TO_ALL)),
                Arrays.asList(
                        new TaskDeclaration(0, range(0, 3)), new TaskDeclaration(1, range(0, 4))));
    }

    @Test
    void testAllToAllEdgeWithSomeSourcesFinished() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(3, range(0, 2)),
                        new VertexDeclaration(4, Collections.emptySet())),
                Collections.singletonList(new EdgeDeclaration(0, 1, DistributionPattern.ALL_TO_ALL)),
                Collections.singletonList(new TaskDeclaration(0, range(2, 3))));
    }

    @Test
    void testOneToOneEdgeWithSomeSourcesFinished() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(4, range(0, 2)),
                        new VertexDeclaration(4, Collections.emptySet())),
                Collections.singletonList(new EdgeDeclaration(0, 1, DistributionPattern.POINTWISE)),
                Arrays.asList(
                        new TaskDeclaration(0, range(2, 4)), new TaskDeclaration(1, range(0, 2))));
    }

    @Test
    void testOneToOnEdgeWithSomeSourcesAndTargetsFinished() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(4, range(0, 2)), new VertexDeclaration(4, of(0))),
                Collections.singletonList(new EdgeDeclaration(0, 1, DistributionPattern.POINTWISE)),
                Arrays.asList(
                        new TaskDeclaration(0, range(2, 4)), new TaskDeclaration(1, range(1, 2))));
    }

    @Test
    void testComputeWithMultipleInputs() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(3, range(0, 3)),
                        new VertexDeclaration(5, of(0, 2, 3)),
                        new VertexDeclaration(5, of(2, 4)),
                        new VertexDeclaration(5, of(2))),
                Arrays.asList(
                        new EdgeDeclaration(0, 3, DistributionPattern.ALL_TO_ALL),
                        new EdgeDeclaration(1, 3, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(2, 3, DistributionPattern.POINTWISE)),
                Arrays.asList(
                        new TaskDeclaration(1, of(1, 4)), new TaskDeclaration(2, of(0, 1, 3))));
    }

    @Test
    void testComputeWithMultipleLevels() throws Exception {
        runSingleTest(
                Arrays.asList(
                        new VertexDeclaration(16, range(0, 4)),
                        new VertexDeclaration(16, range(0, 16)),
                        new VertexDeclaration(16, range(0, 2)),
                        new VertexDeclaration(16, Collections.emptySet()),
                        new VertexDeclaration(16, Collections.emptySet())),
                Arrays.asList(
                        new EdgeDeclaration(0, 2, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(0, 3, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(1, 2, DistributionPattern.ALL_TO_ALL),
                        new EdgeDeclaration(1, 3, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(2, 4, DistributionPattern.POINTWISE),
                        new EdgeDeclaration(3, 4, DistributionPattern.ALL_TO_ALL)),
                Arrays.asList(
                        new TaskDeclaration(0, range(4, 16)),
                        new TaskDeclaration(2, range(2, 4)),
                        new TaskDeclaration(3, range(0, 4))));
    }

    @Test
    void testPlanCalculationWhenOneTaskNotRunning() throws Exception {
        // Cover every combination of which of the two vertices is added as a source vertex.
        runWithNotRunningTask(true, true);
        runWithNotRunningTask(true, false);
        runWithNotRunningTask(false, false);
        runWithNotRunningTask(false, true);
    }

    /**
     * For every {@link ExecutionState} other than {@code RUNNING}, builds a two-vertex graph. One
     * vertex is moved to {@code RUNNING} and the other into that state. Then asserts that plan
     * calculation fails with {@link CheckpointFailureReason#NOT_ALL_REQUIRED_TASKS_RUNNING}.
     *
     * @param isRunningVertexSource whether the running vertex is added as a source vertex
     * @param isNotRunningVertexSource whether the non-running vertex is added as a source vertex
     */
    private void runWithNotRunningTask(
            boolean isRunningVertexSource, boolean isNotRunningVertexSource) throws Exception {
        for (ExecutionState notRunningState :
                EnumSet.complementOf(EnumSet.of(ExecutionState.RUNNING))) {
            JobVertexID runningVertex = new JobVertexID();
            JobVertexID notRunningVertex = new JobVertexID();

            ExecutionGraph graph =
                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
                            .addJobVertex(runningVertex, isRunningVertexSource)
                            .addJobVertex(notRunningVertex, isNotRunningVertexSource)
                            // the states under test are set explicitly below
                            .setTransitToRunning(false)
                            .build(EXECUTOR_EXTENSION.getExecutor());

            transitVertexToState(graph, runningVertex, ExecutionState.RUNNING);
            transitVertexToState(graph, notRunningVertex, notRunningState);

            DefaultCheckpointPlanCalculator checkpointPlanCalculator =
                    createCheckpointPlanCalculator(graph);
            CompletableFuture<CheckpointPlan> planFuture =
                    checkpointPlanCalculator.calculateCheckpointPlan();

            FlinkAssertions.assertThatFuture(planFuture)
                    .withFailMessage(
                            "The computation should fail since some tasks to trigger are in %s state",
                            notRunningState)
                    .eventuallyFailsWith(ExecutionException.class)
                    .havingCause()
                    .isInstanceOfSatisfying(
                            CheckpointException.class,
                            e ->
                                    Assertions.assertThat(e.getCheckpointFailureReason())
                                            .isEqualTo(
                                                    CheckpointFailureReason
                                                            .NOT_ALL_REQUIRED_TASKS_RUNNING));
        }
    }

    /**
     * Transitions the current execution attempt of the first task vertex of the given job vertex
     * into {@code state}.
     *
     * @throws IllegalStateException if the job vertex has no task vertex with that id
     */
    private void transitVertexToState(
            ExecutionGraph graph, JobVertexID jobVertexID, ExecutionState state) {
        Arrays.stream(graph.getJobVertex(jobVertexID).getTaskVertices())
                .filter(vertex -> vertex.getJobvertexId().equals(jobVertexID))
                .findFirst()
                .orElseThrow(
                        () ->
                                new IllegalStateException(
                                        "No task vertex found for job vertex " + jobVertexID))
                .getCurrentExecutionAttempt()
                .transitionState(state);
    }

    /**
     * Runs a plan-calculation test. The expected finished tasks are exactly the finished subtasks
     * declared on each vertex.
     */
    private void runSingleTest(
            List<VertexDeclaration> vertexDeclarations,
            List<EdgeDeclaration> edgeDeclarations,
            List<TaskDeclaration> expectedToTriggerTaskDeclarations)
            throws Exception {
        List<TaskDeclaration> expectedFinishedTaskDeclarations =
                IntStream.range(0, vertexDeclarations.size())
                        .mapToObj(
                                i ->
                                        new TaskDeclaration(
                                                i, vertexDeclarations.get(i).finishedSubtaskIndices))
                        .collect(Collectors.toList());

        runSingleTest(
                vertexDeclarations,
                edgeDeclarations,
                expectedToTriggerTaskDeclarations,
                expectedFinishedTaskDeclarations);
    }

    /**
     * Builds the graph, calculates a checkpoint plan and verifies it.
     *
     * <p>The expected running tasks of each vertex are all of its subtasks minus the expected
     * finished ones. A vertex counts as fully finished when all of its subtasks are expected to be
     * finished.
     */
    private void runSingleTest(
            List<VertexDeclaration> vertexDeclarations,
            List<EdgeDeclaration> edgeDeclarations,
            List<TaskDeclaration> expectedToTriggerTaskDeclarations,
            List<TaskDeclaration> expectedFinishedTaskDeclarations)
            throws Exception {
        ExecutionGraph graph = createExecutionGraph(vertexDeclarations, edgeDeclarations);
        DefaultCheckpointPlanCalculator planCalculator = createCheckpointPlanCalculator(graph);

        List<TaskDeclaration> expectedRunningTaskDeclarations = new ArrayList<>();
        List<ExecutionJobVertex> expectedFullyFinishedJobVertices = new ArrayList<>();
        for (TaskDeclaration finishedDeclaration : expectedFinishedTaskDeclarations) {
            ExecutionJobVertex jobVertex = chooseJobVertex(graph, finishedDeclaration.vertexIndex);
            expectedRunningTaskDeclarations.add(
                    new TaskDeclaration(
                            finishedDeclaration.vertexIndex,
                            minus(
                                    range(0, jobVertex.getParallelism()),
                                    finishedDeclaration.subtaskIndices)));
            if (finishedDeclaration.subtaskIndices.size() == jobVertex.getParallelism()) {
                expectedFullyFinishedJobVertices.add(jobVertex);
            }
        }

        List<ExecutionVertex> expectedRunningTasks =
                chooseTasks(graph, expectedRunningTaskDeclarations.toArray(new TaskDeclaration[0]));
        List<Execution> expectedFinishedTasks =
                toCurrentExecutions(
                        chooseTasks(
                                graph,
                                expectedFinishedTaskDeclarations.toArray(new TaskDeclaration[0])));
        List<ExecutionVertex> expectedToTriggerTasks =
                chooseTasks(
                        graph, expectedToTriggerTaskDeclarations.toArray(new TaskDeclaration[0]));

        CheckpointPlan checkpointPlan = planCalculator.calculateCheckpointPlan().get();

        checkCheckpointPlan(
                expectedToTriggerTasks,
                expectedRunningTasks,
                expectedFinishedTasks,
                expectedFullyFinishedJobVertices,
                checkpointPlan);
    }

    /**
     * Builds and starts an execution graph from the declarations. Vertex {@code i} is named
     * {@code vertexName(i)}. Every subtask is first moved to {@code RUNNING}, then the subtasks
     * listed in each vertex's {@code finishedSubtaskIndices} are marked finished.
     */
    private ExecutionGraph createExecutionGraph(
            List<VertexDeclaration> vertexDeclarations, List<EdgeDeclaration> edgeDeclarations)
            throws Exception {
        JobVertex[] jobVertices = new JobVertex[vertexDeclarations.size()];
        for (int i = 0; i < vertexDeclarations.size(); ++i) {
            jobVertices[i] =
                    ExecutionGraphTestUtils.createJobVertex(
                            vertexName(i),
                            vertexDeclarations.get(i).parallelism,
                            NoOpInvokable.class);
        }

        for (EdgeDeclaration edgeDeclaration : edgeDeclarations) {
            JobVertexConnectionUtils.connectNewDataSetAsInput(
                    jobVertices[edgeDeclaration.target],
                    jobVertices[edgeDeclaration.source],
                    edgeDeclaration.distributionPattern,
                    ResultPartitionType.PIPELINED);
        }

        DefaultExecutionGraph graph =
                ExecutionGraphTestUtils.createExecutionGraph(
                        EXECUTOR_EXTENSION.getExecutor(), jobVertices);
        graph.start(ComponentMainThreadExecutorServiceAdapter.forMainThread());
        graph.transitionToRunning();
        graph.getAllExecutionVertices()
                .forEach(
                        task ->
                                task.getCurrentExecutionAttempt()
                                        .transitionState(ExecutionState.RUNNING));

        for (int i = 0; i < vertexDeclarations.size(); ++i) {
            JobVertexID jobVertexId = jobVertices[i].getID();
            vertexDeclarations
                    .get(i)
                    .finishedSubtaskIndices
                    .forEach(
                            index ->
                                    graph.getJobVertex(jobVertexId)
                                            .getTaskVertices()[index]
                                            .getCurrentExecutionAttempt()
                                            .markFinished());
        }

        return graph;
    }

    /** Creates the calculator under test over all vertices of {@code graph}, in topological order. */
    private DefaultCheckpointPlanCalculator createCheckpointPlanCalculator(ExecutionGraph graph) {
        CheckpointPlanCalculatorContext context =
                new ExecutionGraphCheckpointPlanCalculatorContext(graph);
        return new DefaultCheckpointPlanCalculator(
                graph.getJobID(), context, graph.getVerticesTopologically(), true);
    }

    /**
     * Verifies every part of {@code plan}. The tasks to commit to are compared directly with
     * {@code expectedRunning}. The tasks to wait for are compared with the current execution
     * attempts of {@code expectedRunning}.
     */
    private void checkCheckpointPlan(
            List<ExecutionVertex> expectedToTrigger,
            List<ExecutionVertex> expectedRunning,
            List<Execution> expectedFinished,
            List<ExecutionJobVertex> expectedFullyFinished,
            CheckpointPlan plan) {
        List<Execution> expectedTriggeredExecutions = toCurrentExecutions(expectedToTrigger);
        List<Execution> expectedExecutionsToWaitFor = toCurrentExecutions(expectedRunning);

        assertSameInstancesWithoutOrder(
                "The computed tasks to trigger is different from expected",
                expectedTriggeredExecutions,
                plan.getTasksToTrigger());
        assertSameInstancesWithoutOrder(
                "The computed running tasks is different from expected",
                expectedRunning,
                plan.getTasksToCommitTo());
        assertSameInstancesWithoutOrder(
                "The computed finished tasks is different from expected",
                expectedFinished,
                plan.getFinishedTasks());
        assertSameInstancesWithoutOrder(
                "The computed fully finished JobVertex is different from expected",
                expectedFullyFinished,
                plan.getFullyFinishedJobVertex());
        assertSameInstancesWithoutOrder(
                "The computed tasks to ack is different from expected",
                expectedExecutionsToWaitFor,
                plan.getTasksToWaitFor());
    }

    /** Maps each vertex to its current execution attempt, preserving order. */
    private static List<Execution> toCurrentExecutions(List<ExecutionVertex> vertices) {
        return vertices.stream()
                .map(ExecutionVertex::getCurrentExecutionAttempt)
                .collect(Collectors.toList());
    }

    /**
     * Asserts that {@code actual} contains exactly the elements of {@code expected}, in any order,
     * using AssertJ's default (equals-based) element comparison.
     */
    private <T> void assertSameInstancesWithoutOrder(
            String comment, Collection<T> expected, Collection<T> actual) {
        // The actual value goes into assertThat so failure messages label both sides correctly.
        Assertions.assertThat(actual).as(comment).containsExactlyInAnyOrderElementsOf(expected);
    }

    /** Collects the declared subtasks, vertex by vertex, as {@link ExecutionVertex}es. */
    private List<ExecutionVertex> chooseTasks(
            ExecutionGraph graph, TaskDeclaration... chosenDeclarations) {
        List<ExecutionVertex> tasks = new ArrayList<>();
        for (TaskDeclaration chosenDeclaration : chosenDeclarations) {
            ExecutionVertex[] subtasks =
                    chooseJobVertex(graph, chosenDeclaration.vertexIndex).getTaskVertices();
            for (int subtaskIndex : chosenDeclaration.subtaskIndices) {
                tasks.add(subtasks[subtaskIndex]);
            }
        }
        return tasks;
    }

    /**
     * Finds the job vertex created for declaration {@code vertexIndex} by its name.
     *
     * @throws RuntimeException if no vertex with that name exists in {@code graph}
     */
    private ExecutionJobVertex chooseJobVertex(ExecutionGraph graph, int vertexIndex) {
        String name = vertexName(vertexIndex);
        Optional<ExecutionJobVertex> foundVertex =
                graph.getAllVertices().values().stream()
                        .filter(jobVertex -> jobVertex.getName().equals(name))
                        .findFirst();
        return foundVertex.orElseThrow(
                () -> new RuntimeException("Vertex not found with index " + vertexIndex));
    }

    /** Name given to the job vertex built from the declaration at {@code index}. */
    private String vertexName(int index) {
        return "vertex_" + index;
    }

    /** Returns the integers in {@code [start, end)} as a set. */
    private Set<Integer> range(int start, int end) {
        return IntStream.range(start, end).boxed().collect(Collectors.toSet());
    }

    /** Returns the given indices as a mutable set. */
    private Set<Integer> of(Integer... indices) {
        return new HashSet<>(Arrays.asList(indices));
    }

    /** Returns a new set holding the elements of {@code all} that are not in {@code toMinus}. */
    private Set<Integer> minus(Set<Integer> all, Set<Integer> toMinus) {
        return all.stream().filter(e -> !toMinus.contains(e)).collect(Collectors.toSet());
    }

    /** A job vertex to create: its parallelism and the subtasks that start out finished. */
    private static final class VertexDeclaration {
        final int parallelism;
        final Set<Integer> finishedSubtaskIndices;

        public VertexDeclaration(int parallelism, Set<Integer> finishedSubtaskIndices) {
            this.parallelism = parallelism;
            this.finishedSubtaskIndices = finishedSubtaskIndices;
        }
    }

    /** An edge between two declared vertices, referenced by their declaration index. */
    private static final class EdgeDeclaration {
        final int source;
        final int target;
        final DistributionPattern distributionPattern;

        public EdgeDeclaration(int source, int target, DistributionPattern distributionPattern) {
            this.source = source;
            this.target = target;
            this.distributionPattern = distributionPattern;
        }
    }

    /** A set of subtasks of the declared vertex at {@code vertexIndex}. */
    private static final class TaskDeclaration {
        final int vertexIndex;
        final Set<Integer> subtaskIndices;

        public TaskDeclaration(int vertexIndex, Set<Integer> subtaskIndices) {
            this.vertexIndex = vertexIndex;
            this.subtaskIndices = subtaskIndices;
        }
    }
}

