/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.raft;

import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import net.jqwik.api.AfterFailureMode;
import net.jqwik.api.ForAll;
import net.jqwik.api.Property;
import net.jqwik.api.Tag;
import net.jqwik.api.constraints.IntRange;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.memory.MemoryPool;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.protocol.ObjectSerializationCache;
import org.apache.kafka.common.protocol.Readable;
import org.apache.kafka.common.protocol.Writable;
import org.apache.kafka.common.protocol.types.Type;
import org.apache.kafka.common.utils.BufferSupplier;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.raft.Batch;
import org.apache.kafka.raft.ElectionState;
import org.apache.kafka.raft.ExpirationService;
import org.apache.kafka.raft.KafkaRaftClient;
import org.apache.kafka.raft.MockExpirationService;
import org.apache.kafka.raft.MockLog;
import org.apache.kafka.raft.MockMessageQueue;
import org.apache.kafka.raft.MockNetworkChannel;
import org.apache.kafka.raft.MockQuorumStateStore;
import org.apache.kafka.raft.NetworkChannel;
import org.apache.kafka.raft.OffsetAndEpoch;
import org.apache.kafka.raft.QuorumStateStore;
import org.apache.kafka.raft.RaftClient;
import org.apache.kafka.raft.RaftConfig;
import org.apache.kafka.raft.RaftMessage;
import org.apache.kafka.raft.RaftMessageQueue;
import org.apache.kafka.raft.RaftRequest;
import org.apache.kafka.raft.RaftResponse;
import org.apache.kafka.raft.ReplicatedCounter;
import org.apache.kafka.raft.ReplicatedLog;
import org.apache.kafka.raft.ValidOffsetAndEpoch;
import org.apache.kafka.raft.internals.BatchMemoryPool;
import org.apache.kafka.server.common.serialization.RecordSerde;
import org.apache.kafka.snapshot.RawSnapshotReader;
import org.apache.kafka.snapshot.RecordsSnapshotReader;
import org.junit.jupiter.api.Assertions;

@Tag(value="integration")
public class RaftEventSimulationTest {
    // Partition used for the replicated Raft log in this simulation.
    // NOTE(review): none of the constants below are referenced in the visible part of this file;
    // presumably they are fed into each node's RaftConfig / client when it is started — confirm.
    private static final TopicPartition METADATA_PARTITION = new TopicPartition("__cluster_metadata", 0);
    // Presumably the base election timeout handed to each node (milliseconds).
    private static final int ELECTION_TIMEOUT_MS = 1000;
    // Presumably the random jitter added to the election timeout (milliseconds).
    private static final int ELECTION_JITTER_MS = 100;
    // Presumably the follower fetch timeout (milliseconds).
    private static final int FETCH_TIMEOUT_MS = 3000;
    // Presumably the backoff between retries (milliseconds).
    private static final int RETRY_BACKOFF_MS = 50;
    // Presumably the per-request timeout (milliseconds).
    private static final int REQUEST_TIMEOUT_MS = 3000;
    // Presumably the maximum time a fetch may wait on the leader (milliseconds).
    private static final int FETCH_MAX_WAIT_MS = 100;
    // Presumably the append linger time; 0 would mean appends are not delayed for batching.
    private static final int LINGER_MS = 0;

    /**
     * Starts every node and verifies that the cluster agrees on a single leader and that every
     * running node then advances its high watermark to at least offset 10.
     *
     * <p>Each parameter carries exactly one {@code @IntRange}: the annotation is not repeatable,
     * so the duplicated form emitted by the decompiler does not compile.
     */
    @Property(tries=100, afterFailure=AfterFailureMode.SAMPLE_ONLY)
    void canElectInitialLeader(@ForAll int seed, @ForAll @IntRange(min=1, max=5) int numVoters, @ForAll @IntRange(min=0, max=5) int numObservers) {
        Random random = new Random(seed);
        Cluster cluster = new Cluster(numVoters, numObservers, random);
        MessageRouter router = new MessageRouter(cluster);
        EventScheduler scheduler = this.schedulerWithDefaultInvariants(cluster);
        cluster.startAll();
        this.schedulePolling(scheduler, cluster, 3, 5);
        scheduler.schedule(router::deliverAll, 0, 2, 1);
        scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3);
        scheduler.runUntil(cluster::hasConsistentLeader);
        scheduler.runUntil(() -> cluster.allReachedHighWatermark(10L));
    }

    /**
     * Seeds the cluster with data, then stops the current leader, either gracefully or by killing
     * it. The test checks that the remaining nodes still advance the high watermark, which
     * requires a new leader. Finally it restarts the old leader and checks that it catches up.
     */
    @Property(tries=100, afterFailure=AfterFailureMode.SAMPLE_ONLY)
    void canElectNewLeaderAfterOldLeaderFailure(@ForAll int seed, @ForAll @IntRange(min=3, max=5) int numVoters, @ForAll @IntRange(min=0, max=5) int numObservers, @ForAll boolean isGracefulShutdown) {
        Random random = new Random(seed);
        Cluster cluster = new Cluster(numVoters, numObservers, random);
        MessageRouter router = new MessageRouter(cluster);
        EventScheduler scheduler = this.schedulerWithDefaultInvariants(cluster);

        // Seed the cluster with some committed data.
        cluster.startAll();
        this.schedulePolling(scheduler, cluster, 3, 5);
        scheduler.schedule(router::deliverAll, 0, 2, 1);
        scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3);
        scheduler.runUntil(cluster::hasConsistentLeader);
        scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10L));

        // Take the leader down. Progress past offset 20 is only possible with a new leader.
        int leaderId = cluster.latestLeader().orElseThrow(() -> new AssertionError("Failed to find current leader"));
        if (isGracefulShutdown) {
            cluster.shutdown(leaderId);
        } else {
            cluster.kill(leaderId);
        }
        scheduler.runUntil(() -> cluster.allReachedHighWatermark(20L));

        // Bring the old leader back and verify that it catches up with everyone else.
        long highWatermark = cluster.maxHighWatermarkReached();
        cluster.start(leaderId);
        scheduler.runUntil(() -> cluster.allReachedHighWatermark(highWatermark + 10L));
    }

    /**
     * Seeds the cluster with data and then kills every node. It restarts the first
     * {@code majoritySize()} node ids in {@code cluster.nodes()} iteration order and verifies
     * that the restarted nodes advance the high watermark 10 past where it was before the crash.
     */
    @Property(tries=100, afterFailure=AfterFailureMode.SAMPLE_ONLY)
    void canRecoverAfterAllNodesKilled(@ForAll int seed, @ForAll @IntRange(min=1, max=5) int numVoters, @ForAll @IntRange(min=0, max=5) int numObservers) {
        Random random = new Random(seed);
        Cluster cluster = new Cluster(numVoters, numObservers, random);
        MessageRouter router = new MessageRouter(cluster);
        EventScheduler scheduler = this.schedulerWithDefaultInvariants(cluster);
        cluster.startAll();
        this.schedulePolling(scheduler, cluster, 3, 5);
        scheduler.schedule(router::deliverAll, 0, 2, 1);
        scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3);
        scheduler.runUntil(cluster::hasConsistentLeader);
        scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10L));
        long highWatermark = cluster.maxHighWatermarkReached();

        cluster.killAll();
        // NOTE(review): nodes() iterates a HashMap key set; for small Integer keys this is
        // ascending order, so the restarted nodes are presumably the lowest ids (voters) — confirm.
        Iterator<Integer> nodeIdsIterator = cluster.nodes().iterator();
        for (int i = 0; i < cluster.majoritySize(); ++i) {
            Integer nodeId = nodeIdsIterator.next();
            cluster.start(nodeId);
        }
        scheduler.runUntil(() -> cluster.allReachedHighWatermark(highWatermark + 10L));
    }

    /**
     * Seeds the cluster with data, then isolates the current leader by dropping all of its
     * traffic. The test verifies that every other node reaches high watermark 20, which requires
     * a newly elected leader.
     */
    @Property(tries=100, afterFailure=AfterFailureMode.SAMPLE_ONLY)
    void canElectNewLeaderAfterOldLeaderPartitionedAway(@ForAll int seed, @ForAll @IntRange(min=3, max=5) int numVoters, @ForAll @IntRange(min=0, max=5) int numObservers) {
        Random random = new Random(seed);
        Cluster cluster = new Cluster(numVoters, numObservers, random);
        MessageRouter router = new MessageRouter(cluster);
        EventScheduler scheduler = this.schedulerWithDefaultInvariants(cluster);
        cluster.startAll();
        this.schedulePolling(scheduler, cluster, 3, 5);
        scheduler.schedule(router::deliverAll, 0, 2, 2);
        scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3);
        scheduler.runUntil(cluster::hasConsistentLeader);
        scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10L));

        int leaderId = cluster.latestLeader().orElseThrow(() -> new AssertionError("Failed to find current leader"));
        router.filter(leaderId, new DropAllTraffic());

        Set<Integer> nonPartitionedNodes = new HashSet<Integer>(cluster.nodes());
        nonPartitionedNodes.remove(leaderId);
        scheduler.runUntil(() -> cluster.allReachedHighWatermark(20L, nonPartitionedNodes));
    }

    /**
     * Splits a 5-voter cluster into a minority {0, 1} and a majority {2, 3, 4}. Outbound
     * requests across the split are dropped. The test checks that only the majority side keeps
     * advancing the high watermark. It then heals the split and checks that every node catches up.
     *
     * <p>The node-id sets use {@code Utils.mkSet(2, 3, 4)} directly. The decompiled
     * {@code (Object[])} cast makes the result a {@code Set<Object>}, which does not compile
     * where a {@code Set<Integer>} is required.
     */
    @Property(tries=100, afterFailure=AfterFailureMode.SAMPLE_ONLY)
    void canMakeProgressIfMajorityIsReachable(@ForAll int seed, @ForAll @IntRange(min=0, max=3) int numObservers) {
        int numVoters = 5;
        Random random = new Random(seed);
        Cluster cluster = new Cluster(numVoters, numObservers, random);
        MessageRouter router = new MessageRouter(cluster);
        EventScheduler scheduler = this.schedulerWithDefaultInvariants(cluster);
        cluster.startAll();
        this.schedulePolling(scheduler, cluster, 3, 5);
        scheduler.schedule(router::deliverAll, 0, 2, 2);
        scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3);
        scheduler.runUntil(cluster::hasConsistentLeader);
        scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10L));

        Set<Integer> minority = Utils.mkSet(0, 1);
        Set<Integer> majority = Utils.mkSet(2, 3, 4);

        // Partition: each side drops outbound requests addressed to the other side.
        router.filter(0, new DropOutboundRequestsFrom(majority));
        router.filter(1, new DropOutboundRequestsFrom(majority));
        router.filter(2, new DropOutboundRequestsFrom(minority));
        router.filter(3, new DropOutboundRequestsFrom(minority));
        router.filter(4, new DropOutboundRequestsFrom(minority));

        long partitionLogEndOffset = cluster.maxLogEndOffset();
        scheduler.runUntil(() -> cluster.anyReachedHighWatermark(2L * partitionLogEndOffset));

        long minorityHighWatermark = cluster.maxHighWatermarkReached(minority);
        long majorityHighWatermark = cluster.maxHighWatermarkReached(majority);
        Assertions.assertTrue(majorityHighWatermark > minorityHighWatermark,
            String.format("majorityHighWatermark = %s, minorityHighWatermark = %s", majorityHighWatermark, minorityHighWatermark));

        // Heal the partition for all voters (observers were never filtered).
        for (int voterId = 0; voterId < numVoters; voterId++) {
            router.filter(voterId, new PermitAllTraffic());
        }
        long restoredLogEndOffset = cluster.maxLogEndOffset();
        scheduler.runUntil(() -> cluster.allReachedHighWatermark(2L * restoredLogEndOffset));
    }

    /**
     * Isolates the leader and waits for a different leader to appear. It then restores the old
     * leader and isolates the new one. The test verifies that the high watermark still advances
     * by 10 after these back-to-back leader failures.
     */
    @Property(tries=100, afterFailure=AfterFailureMode.SAMPLE_ONLY)
    void canMakeProgressAfterBackToBackLeaderFailures(@ForAll int seed, @ForAll @IntRange(min=3, max=5) int numVoters, @ForAll @IntRange(min=0, max=5) int numObservers) {
        Random random = new Random(seed);
        Cluster cluster = new Cluster(numVoters, numObservers, random);
        MessageRouter router = new MessageRouter(cluster);
        EventScheduler scheduler = this.schedulerWithDefaultInvariants(cluster);
        cluster.startAll();
        this.schedulePolling(scheduler, cluster, 3, 5);
        scheduler.schedule(router::deliverAll, 0, 2, 5);
        scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3);
        scheduler.runUntil(cluster::hasConsistentLeader);
        scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10L));

        // Fail with a descriptive error (like the sibling tests) instead of a bare NoSuchElementException.
        int leaderId = cluster.latestLeader().orElseThrow(() -> new AssertionError("Failed to find current leader"));
        router.filter(leaderId, new DropAllTraffic());
        scheduler.runUntil(() -> cluster.latestLeader().isPresent() && cluster.latestLeader().getAsInt() != leaderId);

        // The loop condition above guarantees a leader is present here.
        int newLeaderId = cluster.latestLeader().getAsInt();
        router.filter(leaderId, new PermitAllTraffic());
        router.filter(newLeaderId, new DropAllTraffic());

        long targetHighWatermark = cluster.maxHighWatermarkReached() + 10L;
        scheduler.runUntil(() -> cluster.anyReachedHighWatermark(targetHighWatermark));
    }

    /**
     * Kills a random running node and deletes its persistent state. It waits until a consistent
     * leader other than that node exists, then restarts the node. The test checks that every
     * running node reaches the pre-restart high watermark plus 10, and that committed data stays
     * consistent throughout.
     *
     * <p>Only the MonotonicHighWatermark and SingleLeader invariants are registered here, unlike
     * {@link #schedulerWithDefaultInvariants}. Presumably wiping a node's state can legitimately
     * violate the others (e.g. per-node epoch monotonicity) — confirm.
     */
    @Property(tries=100, afterFailure=AfterFailureMode.SAMPLE_ONLY)
    void canRecoverFromSingleNodeCommittedDataLoss(@ForAll int seed, @ForAll @IntRange(min=3, max=5) int numVoters, @ForAll @IntRange(min=0, max=2) int numObservers) {
        Random random = new Random(seed);
        Cluster cluster = new Cluster(numVoters, numObservers, random);
        EventScheduler scheduler = new EventScheduler(cluster.random, cluster.time);
        scheduler.addInvariant(new MonotonicHighWatermark(cluster));
        scheduler.addInvariant(new SingleLeader(cluster));
        scheduler.addValidation(new ConsistentCommittedData(cluster));
        MessageRouter router = new MessageRouter(cluster);
        cluster.startAll();
        this.schedulePolling(scheduler, cluster, 3, 5);
        scheduler.schedule(router::deliverAll, 0, 2, 5);
        scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3);
        scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10L));

        RaftNode node = cluster.randomRunning().orElseThrow(() -> new AssertionError("Failed to find running node"));
        cluster.killAndDeletePersistentState(node.nodeId);
        scheduler.runUntil(() -> !cluster.hasLeader(node.nodeId) && cluster.hasConsistentLeader());

        long highWatermarkBeforeRestart = cluster.maxHighWatermarkReached();
        cluster.start(node.nodeId);
        scheduler.runUntil(() -> cluster.allReachedHighWatermark(highWatermarkBeforeRestart + 10L));
    }

    /**
     * Builds a scheduler driven by the cluster's random source and mock clock. It registers the
     * standard set of invariants (in a fixed order) plus the committed-data validation.
     */
    private EventScheduler schedulerWithDefaultInvariants(Cluster cluster) {
        EventScheduler result = new EventScheduler(cluster.random, cluster.time);
        Invariant[] defaults = {
            new MonotonicHighWatermark(cluster),
            new MonotonicEpoch(cluster),
            new MajorityReachedHighWatermark(cluster),
            new SingleLeader(cluster),
            new SnapshotAtLogStart(cluster),
            new LeaderNeverLoadSnapshot(cluster)
        };
        for (Invariant invariant : defaults) {
            result.addInvariant(invariant);
        }
        result.addValidation(new ConsistentCommittedData(cluster));
        return result;
    }

    /**
     * Schedules one recurring poll per node. Initial delays are staggered by 1 ms in node-id
     * iteration order, so the nodes do not all poll at the same instant.
     */
    private void schedulePolling(EventScheduler scheduler, Cluster cluster, int pollIntervalMs, int pollJitterMs) {
        int initialDelayMs = 0;
        for (Integer id : cluster.nodes()) {
            final int pollTarget = id;
            scheduler.schedule(() -> cluster.pollIfRunning(pollTarget), initialDelayMs++, pollIntervalMs, pollJitterMs);
        }
    }

    private static class IntSerde
    implements RecordSerde<Integer> {
        private IntSerde() {
        }

        public int recordSize(Integer data, ObjectSerializationCache serializationCache) {
            return Type.INT32.sizeOf((Object)data);
        }

        public void write(Integer data, ObjectSerializationCache serializationCache, Writable out) {
            out.writeInt(data.intValue());
        }

        public Integer read(Readable input, int size) {
            return input.readInt();
        }
    }

    /**
     * Simulated network. It moves requests from senders' send queues to running destination
     * nodes, and routes each response back to the request's source. Every message passes
     * through the per-node {@link NetworkFilter}s.
     */
    private static class MessageRouter {
        final Map<Integer, InflightRequest> inflight = new HashMap<Integer, InflightRequest>();
        final Map<Integer, NetworkFilter> filters = new HashMap<Integer, NetworkFilter>();
        final Cluster cluster;

        private MessageRouter(Cluster cluster) {
            this.cluster = cluster;
            // Every known node starts with an unrestricted network.
            cluster.nodes.keySet().forEach(id -> filters.put(id, new PermitAllTraffic()));
        }

        /** Delivers an outbound request from {@code senderId}, if both endpoints' filters allow it. */
        void deliver(int senderId, RaftRequest.Outbound outbound) {
            boolean senderAllows = filters.get(senderId).acceptOutbound(outbound);
            if (!senderAllows) {
                return;
            }
            final int correlationId = outbound.correlationId();
            final int destinationId = outbound.destinationId();
            final RaftRequest.Inbound request = new RaftRequest.Inbound(correlationId, outbound.data(), cluster.time.milliseconds());
            boolean destinationAllows = filters.get(destinationId).acceptInbound(request);
            if (!destinationAllows) {
                return;
            }
            cluster.nodeIfRunning(destinationId).ifPresent(target -> dispatchRequest(target, senderId, destinationId, request));
        }

        /**
         * Records the in-flight request and hooks its completion so the response is routed back.
         * The destination's outbound filter is consulted when the response completes. Then it
         * hands the request to the target client.
         */
        private void dispatchRequest(RaftNode target, int senderId, int destinationId, RaftRequest.Inbound request) {
            int correlationId = request.correlationId();
            inflight.put(correlationId, new InflightRequest(correlationId, senderId, destinationId));
            request.completion.whenComplete((reply, error) -> {
                if (reply == null) {
                    return;
                }
                if (filters.get(destinationId).acceptOutbound(reply)) {
                    deliver(destinationId, (RaftResponse.Outbound) reply);
                }
            });
            target.client.handle(request);
        }

        /** Routes a response from {@code senderId} back to whichever node issued the request. */
        void deliver(int senderId, RaftResponse.Outbound outbound) {
            int correlationId = outbound.correlationId();
            final RaftResponse.Inbound reply = new RaftResponse.Inbound(correlationId, outbound.data(), senderId);
            int requesterId = inflight.remove(correlationId).sourceId;
            if (filters.get(requesterId).acceptInbound(reply)) {
                cluster.nodeIfRunning(requesterId).ifPresent(requester -> requester.channel.mockReceive(reply));
            }
        }

        /** Replaces the network filter applied to {@code nodeId}. */
        void filter(int nodeId, NetworkFilter filter) {
            filters.put(nodeId, filter);
        }

        /** Drains {@code node}'s send queue and delivers every queued request. */
        void deliverTo(RaftNode node) {
            node.channel.drainSendQueue().forEach(msg -> deliver(node.nodeId, (RaftRequest.Outbound) msg));
        }

        /** Drains the send queues of all running nodes. */
        void deliverAll() {
            cluster.running().forEach(this::deliverTo);
        }
    }

    /**
     * Validation that the sequence number recorded at each committed offset never changes. It
     * covers offsets read from a node's earliest snapshot and from its log, below the node's high
     * watermark. The first value seen for an offset, on any node, becomes the reference.
     */
    private static class ConsistentCommittedData
    implements Validation {
        final Cluster cluster;
        final Map<Long, Integer> committedSequenceNumbers = new HashMap<Long, Integer>();

        private ConsistentCommittedData(Cluster cluster) {
            this.cluster = cluster;
        }

        private int parseSequenceNumber(ByteBuffer value) {
            return (Integer) Type.INT32.read(value);
        }

        private void assertCommittedData(RaftNode node) {
            int nodeId = node.nodeId;
            KafkaRaftClient<Integer> manager = node.client;
            MockLog log = node.log;
            OptionalLong highWatermark = manager.highWatermark();
            if (!highWatermark.isPresent()) {
                // Nothing is known to be committed on this node yet.
                return;
            }
            AtomicLong startOffset = new AtomicLong(0L);
            log.earliestSnapshotId().ifPresent(snapshotId -> {
                Assertions.assertTrue(snapshotId.offset() <= highWatermark.getAsLong());
                startOffset.set(snapshotId.offset());
                // Typed reader/batch instead of the decompiled raw types and manual casts.
                try (RecordsSnapshotReader<Integer> snapshot = RecordsSnapshotReader.of(
                        log.readSnapshot(snapshotId).get(), node.intSerde, BufferSupplier.create(), Integer.MAX_VALUE, true)) {
                    // The snapshot is expected to hold exactly one batch with exactly one record.
                    Assertions.assertTrue(snapshot.hasNext());
                    Batch<Integer> batch = snapshot.next();
                    Assertions.assertFalse(snapshot.hasNext());
                    Assertions.assertEquals(1, batch.records().size());
                    // The snapshot id offset is exclusive, so its record belongs to the preceding offset.
                    long offset = snapshotId.offset() - 1L;
                    int sequence = batch.records().get(0);
                    committedSequenceNumbers.putIfAbsent(offset, sequence);
                    Assertions.assertEquals(committedSequenceNumbers.get(offset), sequence,
                        String.format("Committed sequence at offset %s changed on node %s", offset, nodeId));
                }
            });
            for (MockLog.LogBatch batch : log.readBatches(startOffset.get(), highWatermark)) {
                if (batch.isControlBatch) {
                    continue;
                }
                for (MockLog.LogEntry entry : batch.entries) {
                    long offset = entry.offset;
                    Assertions.assertTrue(offset < highWatermark.getAsLong());
                    int sequence = parseSequenceNumber(entry.record.value().duplicate());
                    committedSequenceNumbers.putIfAbsent(offset, sequence);
                    int committedSequence = committedSequenceNumbers.get(offset);
                    Assertions.assertEquals(committedSequence, sequence,
                        "Committed sequence at offset " + offset + " changed on node " + nodeId);
                }
            }
        }

        @Override
        public void validate() {
            cluster.forAllRunning(this::assertCommittedData);
        }
    }

    /**
     * Invariant: a running node whose counter reports {@code isWritable()} must have recorded
     * zero snapshot-load callbacks.
     */
    private static class LeaderNeverLoadSnapshot implements Invariant {
        final Cluster cluster;

        private LeaderNeverLoadSnapshot(Cluster cluster) {
            this.cluster = cluster;
        }

        @Override
        public void verify() {
            for (RaftNode node : cluster.running()) {
                if (node.counter.isWritable()) {
                    Assertions.assertEquals(0, node.counter.handleLoadSnapshotCalls());
                }
            }
        }
    }

    /**
     * Invariant over every node that has a snapshot, checked against its earliest snapshot:
     * <ul>
     *   <li>the log start offset is not past the snapshot offset;</li>
     *   <li>the log validates the snapshot's (offset, epoch) as valid;</li>
     *   <li>if the log start offset is non-zero, it equals the snapshot offset.</li>
     * </ul>
     * The last assertion message previously read "mising"; it now reads "missing".
     */
    private static class SnapshotAtLogStart
    implements Invariant {
        final Cluster cluster;

        private SnapshotAtLogStart(Cluster cluster) {
            this.cluster = cluster;
        }

        @Override
        public void verify() {
            for (Map.Entry<Integer, PersistentState> nodeEntry : cluster.nodes.entrySet()) {
                int nodeId = nodeEntry.getKey();
                MockLog log = nodeEntry.getValue().log;
                log.earliestSnapshotId().ifPresent(earliestSnapshotId -> {
                    long logStartOffset = log.startOffset();
                    ValidOffsetAndEpoch validateOffsetAndEpoch = log.validateOffsetAndEpoch(earliestSnapshotId.offset(), earliestSnapshotId.epoch());
                    Assertions.assertTrue(logStartOffset <= earliestSnapshotId.offset(),
                        () -> String.format("invalid log start offset (%s) and snapshotId offset (%s): nodeId = %s", logStartOffset, earliestSnapshotId.offset(), nodeId));
                    Assertions.assertEquals(ValidOffsetAndEpoch.valid(earliestSnapshotId), validateOffsetAndEpoch,
                        () -> String.format("invalid leader epoch cache: nodeId = %s", nodeId));
                    if (logStartOffset > 0L) {
                        Assertions.assertEquals(logStartOffset, earliestSnapshotId.offset(),
                            () -> String.format("missing snapshot at log start offset: nodeId = %s", nodeId));
                    }
                });
            }
        }
    }

    /**
     * Invariant: the high watermark reported by {@code cluster.leaderHighWatermark()} never
     * decreases between successive checks. Checks where no value is reported are skipped.
     */
    private static class MonotonicHighWatermark implements Invariant {
        final Cluster cluster;
        long highWatermark = 0L;

        private MonotonicHighWatermark(Cluster cluster) {
            this.cluster = cluster;
        }

        @Override
        public void verify() {
            OptionalLong reported = cluster.leaderHighWatermark();
            if (!reported.isPresent()) {
                return;
            }
            long previous = highWatermark;
            long latest = reported.getAsLong();
            // Record the new value before asserting, exactly as before.
            highWatermark = latest;
            if (latest < previous) {
                Assertions.fail("Non-monotonic update of high watermark detected: " + previous + " -> " + latest);
            }
        }
    }

    /**
     * Invariant: across all nodes' persisted election states, at most one leader id is ever
     * observed for a given epoch. States with an older epoch than the highest seen, or without a
     * leader, are ignored.
     */
    private static class SingleLeader implements Invariant {
        final Cluster cluster;
        int epoch = 0;
        OptionalInt leaderId = OptionalInt.empty();

        private SingleLeader(Cluster cluster) {
            this.cluster = cluster;
        }

        @Override
        public void verify() {
            for (PersistentState state : cluster.nodes.values()) {
                ElectionState election = state.store.readElectionState();
                boolean relevant = election != null && election.epoch >= epoch && election.hasLeader();
                if (!relevant) {
                    continue;
                }
                if (election.epoch == epoch && leaderId.isPresent()) {
                    Assertions.assertEquals(leaderId.getAsInt(), election.leaderId());
                } else {
                    epoch = election.epoch;
                    leaderId = OptionalInt.of(election.leaderId());
                }
            }
        }
    }

    /**
     * Invariant: whenever a leader high watermark is available, at least a majority of voters
     * have a log end offset at or beyond it.
     */
    private static class MajorityReachedHighWatermark implements Invariant {
        final Cluster cluster;

        private MajorityReachedHighWatermark(Cluster cluster) {
            this.cluster = cluster;
        }

        @Override
        public void verify() {
            OptionalLong leaderHw = cluster.leaderHighWatermark();
            if (!leaderHw.isPresent()) {
                return;
            }
            long hw = leaderHw.getAsLong();
            long caughtUp = 0L;
            for (Map.Entry<Integer, PersistentState> entry : cluster.nodes.entrySet()) {
                boolean isVoter = cluster.voters.contains(entry.getKey());
                if (isVoter && entry.getValue().log.endOffset().offset >= hw) {
                    caughtUp++;
                }
            }
            Assertions.assertTrue(caughtUp >= (long) cluster.majoritySize(), "Insufficient nodes have reached current high watermark");
        }
    }

    /**
     * Invariant: each node's persisted election epoch never decreases. For running nodes, the
     * persisted epoch matches the epoch held by the client's quorum state.
     */
    private static class MonotonicEpoch implements Invariant {
        final Cluster cluster;
        final Map<Integer, Integer> nodeEpochs = new HashMap<Integer, Integer>();

        private MonotonicEpoch(Cluster cluster) {
            this.cluster = cluster;
            // Every node starts at epoch 0.
            for (Integer id : cluster.nodes.keySet()) {
                nodeEpochs.put(id, 0);
            }
        }

        @Override
        public void verify() {
            for (Map.Entry<Integer, PersistentState> entry : cluster.nodes.entrySet()) {
                Integer nodeId = entry.getKey();
                Integer previousEpoch = nodeEpochs.get(nodeId);
                ElectionState election = entry.getValue().store.readElectionState();
                if (election != null) {
                    Integer currentEpoch = election.epoch;
                    if (previousEpoch > currentEpoch) {
                        Assertions.fail("Non-monotonic update of epoch detected on node " + nodeId + ": " + previousEpoch + " -> " + currentEpoch);
                    }
                    cluster.ifRunning(nodeId, live -> Assertions.assertEquals((int) currentEpoch, (int) live.client.quorum().epoch()));
                    nodeEpochs.put(nodeId, currentEpoch);
                }
            }
        }
    }

    /**
     * Filter that drops outbound requests addressed to any node in {@code unreachable}. All
     * other traffic, including responses and inbound messages, passes.
     */
    private static class DropOutboundRequestsFrom implements NetworkFilter {
        private final Set<Integer> unreachable;

        private DropOutboundRequestsFrom(Set<Integer> unreachable) {
            this.unreachable = unreachable;
        }

        @Override
        public boolean acceptInbound(RaftMessage message) {
            return true;
        }

        @Override
        public boolean acceptOutbound(RaftMessage message) {
            if (!(message instanceof RaftRequest.Outbound)) {
                return true;
            }
            int destination = ((RaftRequest.Outbound) message).destinationId();
            return !unreachable.contains(destination);
        }
    }

    /** Filter that rejects every message in both directions, fully isolating the node. */
    private static class DropAllTraffic implements NetworkFilter {
        private DropAllTraffic() {
        }

        @Override
        public boolean acceptInbound(RaftMessage message) {
            return Boolean.FALSE;
        }

        @Override
        public boolean acceptOutbound(RaftMessage message) {
            return Boolean.FALSE;
        }
    }

    /** Filter that accepts every message in both directions (an unrestricted network). */
    private static class PermitAllTraffic implements NetworkFilter {
        private PermitAllTraffic() {
        }

        @Override
        public boolean acceptInbound(RaftMessage message) {
            return Boolean.TRUE;
        }

        @Override
        public boolean acceptOutbound(RaftMessage message) {
            return Boolean.TRUE;
        }
    }

    private static interface NetworkFilter {
        public boolean acceptInbound(RaftMessage var1);

        public boolean acceptOutbound(RaftMessage var1);
    }

    /** Immutable record of a routed request: its correlation id, sender and destination. */
    private static class InflightRequest {
        final int correlationId;
        final int sourceId;
        final int destinationId;

        private InflightRequest(int correlationId, int sourceId, int destinationId) {
            this.destinationId = destinationId;
            this.sourceId = sourceId;
            this.correlationId = correlationId;
        }
    }

    /**
     * A live simulated node. It bundles the Raft client with its log, network channel, message
     * queue, state store and a {@link ReplicatedCounter} listener.
     */
    private static class RaftNode {
        final int nodeId;
        final KafkaRaftClient<Integer> client;
        final MockLog log;
        final MockNetworkChannel channel;
        final MockMessageQueue messageQueue;
        final MockQuorumStateStore store;
        final LogContext logContext;
        final ReplicatedCounter counter;
        final Time time;
        final Random random;
        final RecordSerde<Integer> intSerde;

        private RaftNode(int nodeId, KafkaRaftClient<Integer> client, MockLog log, MockNetworkChannel channel, MockMessageQueue messageQueue, MockQuorumStateStore store, LogContext logContext, Time time, Random random, RecordSerde<Integer> intSerde) {
            this.nodeId = nodeId;
            this.client = client;
            this.log = log;
            this.channel = channel;
            this.messageQueue = messageQueue;
            this.store = store;
            this.logContext = logContext;
            this.time = time;
            this.random = random;
            this.counter = new ReplicatedCounter(nodeId, client, logContext);
            this.intSerde = intSerde;
        }

        /** Registers the counter as the client's listener, then initializes the client. */
        void initialize() {
            client.register(counter);
            client.initialize();
        }

        /**
         * Polls the client at least once. It keeps polling while the client is running and its
         * message queue is non-empty. Any exception is wrapped with this node's id.
         */
        void poll() {
            try {
                client.poll();
                while (client.isRunning() && !messageQueue.isEmpty()) {
                    client.poll();
                }
            } catch (Exception e) {
                throw new RuntimeException("Uncaught exception during poll of node " + nodeId, e);
            }
        }

        /** High watermark offset known to this node's quorum state, or 0 if none. */
        long highWatermark() {
            return client.quorum().highWatermark().map(metadata -> metadata.offset).orElse(0L);
        }

        /** Offset of this node's log end. */
        long logEndOffset() {
            return log.endOffset().offset;
        }

        @Override
        public String toString() {
            return String.format("Node(id=%s, hw=%s, logEndOffset=%s)", nodeId, highWatermark(), logEndOffset());
        }
    }

    private static class Cluster {
        // Random source supplied by the test; shared with the EventScheduler built by the tests.
        final Random random;
        // NOTE(review): not used in the visible code; presumably hands out unique request
        // correlation ids to all nodes (MessageRouter keys in-flight requests by them) — confirm.
        final AtomicInteger correlationIdCounter = new AtomicInteger();
        // Simulated clock shared by every node, the router and the scheduler.
        final MockTime time = new MockTime();
        final Uuid clusterId = Uuid.randomUuid();
        // Voter ids; the constructor assigns ids [0, numVoters) as voters.
        final Set<Integer> voters = new HashSet<Integer>();
        // Persistent state for every node (voters and observers), keyed by node id.
        final Map<Integer, PersistentState> nodes = new HashMap<Integer, PersistentState>();
        // Currently running nodes, keyed by node id (subset of `nodes`).
        final Map<Integer, RaftNode> running = new HashMap<Integer, RaftNode>();

        private Cluster(int numVoters, int numObservers, Random random) {
            int nodeId;
            this.random = random;
            for (nodeId = 0; nodeId < numVoters; ++nodeId) {
                this.voters.add(nodeId);
                this.nodes.put(nodeId, new PersistentState(nodeId));
            }
            while (nodeId < numVoters + numObservers) {
                this.nodes.put(nodeId, new PersistentState(nodeId));
                ++nodeId;
            }
        }

        /** Live key-set view of all node ids (voters and observers), running or not. */
        Set<Integer> nodes() {
            Set<Integer> allIds = nodes.keySet();
            return allIds;
        }

        /** Smallest number of voters that forms a strict majority. */
        int majoritySize() {
            int voterCount = voters.size();
            return (voterCount / 2) + 1;
        }

        long maxLogEndOffset() {
            return this.running.values().stream().mapToLong(RaftNode::logEndOffset).max().orElse(0L);
        }

        OptionalLong leaderHighWatermark() {
            Optional<RaftNode> leaderWithMaxEpoch = this.running.values().stream().filter(node -> node.client.quorum().isLeader()).max((node1, node2) -> Integer.compare(node2.client.quorum().epoch(), node1.client.quorum().epoch()));
            if (leaderWithMaxEpoch.isPresent()) {
                return leaderWithMaxEpoch.get().client.highWatermark();
            }
            return OptionalLong.empty();
        }

        boolean anyReachedHighWatermark(long offset) {
            return this.running.values().stream().anyMatch(node -> node.highWatermark() > offset);
        }

        long maxHighWatermarkReached() {
            return this.running.values().stream().mapToLong(RaftNode::highWatermark).max().orElse(0L);
        }

        long maxHighWatermarkReached(Set<Integer> nodeIds) {
            return this.running.values().stream().filter(node -> nodeIds.contains(node.nodeId)).mapToLong(RaftNode::highWatermark).max().orElse(0L);
        }

        boolean allReachedHighWatermark(long offset, Set<Integer> nodeIds) {
            return nodeIds.stream().allMatch(nodeId -> this.running.get(nodeId).highWatermark() >= offset);
        }

        boolean allReachedHighWatermark(long offset) {
            return this.running.values().stream().allMatch(node -> node.highWatermark() >= offset);
        }

        boolean hasLeader(int nodeId) {
            OptionalInt latestLeader = this.latestLeader();
            return latestLeader.isPresent() && latestLeader.getAsInt() == nodeId;
        }

        OptionalInt latestLeader() {
            OptionalInt latestLeader = OptionalInt.empty();
            int latestEpoch = 0;
            for (RaftNode node : this.running.values()) {
                if (node.client.quorum().epoch() > latestEpoch) {
                    latestLeader = node.client.quorum().leaderId();
                    latestEpoch = node.client.quorum().epoch();
                    continue;
                }
                if (node.client.quorum().epoch() != latestEpoch || !node.client.quorum().leaderId().isPresent()) continue;
                latestLeader = node.client.quorum().leaderId();
            }
            return latestLeader;
        }

        boolean hasConsistentLeader() {
            Iterator<RaftNode> iter = this.running.values().iterator();
            if (!iter.hasNext()) {
                return false;
            }
            RaftNode first = iter.next();
            ElectionState election = first.store.readElectionState();
            if (!election.hasLeader()) {
                return false;
            }
            while (iter.hasNext()) {
                RaftNode next = iter.next();
                if (election.equals((Object)next.store.readElectionState())) continue;
                return false;
            }
            return true;
        }

        void killAll() {
            this.running.clear();
        }

        void kill(int nodeId) {
            this.running.remove(nodeId);
        }

        void shutdown(int nodeId) {
            RaftNode node = this.running.get(nodeId);
            if (node == null) {
                throw new IllegalStateException("Attempt to shutdown a node which is not currently running");
            }
            node.client.shutdown(500).whenComplete((res, exception) -> this.kill(nodeId));
        }

        void pollIfRunning(int nodeId) {
            this.ifRunning(nodeId, RaftNode::poll);
        }

        Optional<RaftNode> nodeIfRunning(int nodeId) {
            return Optional.ofNullable(this.running.get(nodeId));
        }

        Collection<RaftNode> running() {
            return this.running.values();
        }

        void ifRunning(int nodeId, Consumer<RaftNode> action) {
            this.nodeIfRunning(nodeId).ifPresent(action);
        }

        Optional<RaftNode> randomRunning() {
            ArrayList<RaftNode> nodes = new ArrayList<RaftNode>(this.running.values());
            if (nodes.isEmpty()) {
                return Optional.empty();
            }
            return Optional.of((RaftNode)nodes.get(this.random.nextInt(nodes.size())));
        }

        void withCurrentLeader(Consumer<RaftNode> action) {
            for (RaftNode node : this.running.values()) {
                if (!node.client.quorum().isLeader()) continue;
                action.accept(node);
            }
        }

        void forAllRunning(Consumer<RaftNode> action) {
            this.running.values().forEach(action);
        }

        void startAll() {
            if (!this.running.isEmpty()) {
                throw new IllegalStateException("Some nodes are already started");
            }
            for (int voterId : this.nodes.keySet()) {
                this.start(voterId);
            }
        }

        void killAndDeletePersistentState(int nodeId) {
            this.kill(nodeId);
            this.nodes.put(nodeId, new PersistentState(nodeId));
        }

        private static RaftConfig.AddressSpec nodeAddress(int id) {
            return new RaftConfig.InetAddressSpec(new InetSocketAddress("localhost", 9990 + id));
        }

        void start(int nodeId) {
            LogContext logContext = new LogContext("[Node " + nodeId + "] ");
            PersistentState persistentState = this.nodes.get(nodeId);
            MockNetworkChannel channel = new MockNetworkChannel(this.correlationIdCounter, this.voters);
            MockMessageQueue messageQueue = new MockMessageQueue();
            Map<Integer, RaftConfig.AddressSpec> voterAddressMap = this.voters.stream().collect(Collectors.toMap(id -> id, Cluster::nodeAddress));
            RaftConfig raftConfig = new RaftConfig(voterAddressMap, 3000, 50, 1000, 100, 3000, 0);
            Metrics metrics = new Metrics((Time)this.time);
            persistentState.log.reopen();
            IntSerde serde = new IntSerde();
            BatchMemoryPool memoryPool = new BatchMemoryPool(2, 0x800000);
            KafkaRaftClient client = new KafkaRaftClient((RecordSerde)serde, (NetworkChannel)channel, (RaftMessageQueue)messageQueue, (ReplicatedLog)persistentState.log, (QuorumStateStore)persistentState.store, (MemoryPool)memoryPool, (Time)this.time, metrics, (ExpirationService)new MockExpirationService(this.time), 100, this.clusterId.toString(), OptionalInt.of(nodeId), logContext, this.random, raftConfig);
            RaftNode node = new RaftNode(nodeId, client, persistentState.log, channel, messageQueue, persistentState.store, logContext, (Time)this.time, this.random, serde);
            node.initialize();
            this.running.put(nodeId, node);
        }
    }

    /**
     * State of one node that survives restarts: its quorum-state store and its log.
     * A node is reset by replacing its PersistentState with a new instance.
     */
    private static class PersistentState {
        final MockQuorumStateStore store = new MockQuorumStateStore();
        final MockLog log;

        PersistentState(int nodeId) {
            LogContext nodeLogContext = new LogContext(String.format("[Node %s] ", nodeId));
            log = new MockLog(METADATA_PARTITION, Uuid.METADATA_TOPIC_ID, nodeLogContext);
        }
    }

    /**
     * Drives the simulation by running queued events in deadline order on a shared clock.
     *
     * <p>Before each event runs, the clock is advanced (via {@link Time#sleep}) up to that event's
     * deadline. Every registered invariant is verified after each event. Validations run once, after
     * {@link #runUntil} has satisfied its exit condition.
     */
    private static class EventScheduler {
        private static final int MAX_ITERATIONS = 500000;

        final AtomicInteger eventIdGenerator = new AtomicInteger(0);
        final PriorityQueue<Event> queue = new PriorityQueue<>();
        final Random random;
        final Time time;
        final List<Invariant> invariants = new ArrayList<>();
        final List<Validation> validations = new ArrayList<>();

        private EventScheduler(Random random, Time time) {
            this.random = random;
            this.time = time;
        }

        /** Registers a check that must hold after every executed event. */
        private void addInvariant(Invariant invariant) {
            invariants.add(invariant);
        }

        /** Registers a check that runs once at the end of {@link #runUntil}. */
        private void addValidation(Validation validation) {
            validations.add(validation);
        }

        /**
         * Schedules {@code action} to first run {@code delayMs} after the current time. It is then
         * repeated every {@code periodMs} plus a random jitter in {@code [0, jitterMs)}.
         */
        void schedule(Runnable action, int delayMs, int periodMs, int jitterMs) {
            long initialDeadlineMs = time.milliseconds() + delayMs;
            int eventId = eventIdGenerator.incrementAndGet();
            PeriodicEvent event = new PeriodicEvent(action, eventId, random, initialDeadlineMs, periodMs, jitterMs);
            queue.offer(event);
        }

        /**
         * Executes events until {@code exitCondition} holds or {@link #MAX_ITERATIONS} events have
         * run, then asserts the condition and runs all validations.
         *
         * @throws IllegalStateException if the queue empties before the condition is satisfied
         */
        void runUntil(Supplier<Boolean> exitCondition) {
            for (int iteration = 0; iteration < MAX_ITERATIONS && !exitCondition.get(); iteration++) {
                if (queue.isEmpty()) {
                    throw new IllegalStateException("Event queue exhausted before condition was satisfied");
                }
                Event event = queue.poll();
                // Never move the clock backwards for events whose deadline has already passed.
                long delayMs = Math.max(event.deadlineMs - time.milliseconds(), 0L);
                time.sleep(delayMs);
                event.execute(this);
                invariants.forEach(Invariant::verify);
            }
            Assertions.assertTrue(exitCondition.get(),
                "Simulation condition was not satisfied after " + MAX_ITERATIONS + " iterations");
            validations.forEach(Validation::validate);
        }
    }

    private static interface Validation {
        public void validate();
    }

    private static interface Invariant {
        public void verify();
    }

    /**
     * On each run, increments the counter on every running node that believes it is leader.
     * Nodes that are shutting down, or whose counter is not writable, are skipped.
     */
    private static class SequentialAppendAction implements Runnable {
        final Cluster cluster;

        private SequentialAppendAction(Cluster cluster) {
            this.cluster = cluster;
        }

        @Override
        public void run() {
            cluster.withCurrentLeader(SequentialAppendAction::appendIfWritable);
        }

        /** Increments the leader's counter unless the leader is shutting down or not writable. */
        private static void appendIfWritable(RaftNode leader) {
            if (leader.client.isShuttingDown() || !leader.counter.isWritable()) {
                return;
            }
            leader.counter.increment();
        }
    }

    /**
     * An event that, after running its action, reschedules the same action. The next run is
     * {@code periodMs} plus a random jitter in {@code [0, jitterMs)} later (no jitter when
     * {@code jitterMs} is 0).
     */
    private static class PeriodicEvent extends Event {
        final Random random;
        final int periodMs;
        final int jitterMs;

        protected PeriodicEvent(Runnable action, int eventId, Random random, long deadlineMs, int periodMs, int jitterMs) {
            super(action, eventId, deadlineMs);
            this.random = random;
            this.periodMs = periodMs;
            this.jitterMs = jitterMs;
        }

        @Override
        void execute(EventScheduler scheduler) {
            super.execute(scheduler);
            int jitter = 0;
            if (jitterMs != 0) {
                jitter = random.nextInt(jitterMs);
            }
            scheduler.schedule(action, periodMs + jitter, periodMs, jitterMs);
        }
    }

    /**
     * A scheduled action. Events are ordered by deadline, with ties broken by event id so that
     * equal-deadline events run in scheduling order.
     */
    private static abstract class Event implements Comparable<Event> {
        final int eventId;
        final long deadlineMs;
        final Runnable action;

        protected Event(Runnable action, int eventId, long deadlineMs) {
            this.action = action;
            this.eventId = eventId;
            this.deadlineMs = deadlineMs;
        }

        /** Runs the action; subclasses may extend this (for example, to reschedule). */
        void execute(EventScheduler scheduler) {
            action.run();
        }

        @Override
        public int compareTo(Event other) {
            int byDeadline = Long.compare(deadlineMs, other.deadlineMs);
            return byDeadline != 0 ? byDeadline : Integer.compare(eventId, other.eventId);
        }
    }
}

