package org.apache.drill.exec.server;

import ch.qos.logback.classic.Level;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.Iterator;
import org.apache.commons.math3.util.Pair;
import org.apache.drill.SingleRowListener;
import org.apache.drill.common.DrillAutoCloseables;
import org.apache.drill.common.concurrent.ExtendedLatch;
import org.apache.drill.common.exceptions.UserException;
import org.apache.drill.common.types.TypeProtos;
import org.apache.drill.exec.ZookeeperTestUtil;
import org.apache.drill.exec.client.DrillClient;
import org.apache.drill.exec.memory.BufferAllocator;
import org.apache.drill.exec.memory.RootAllocatorFactory;
import org.apache.drill.exec.physical.impl.ScreenCreator;
import org.apache.drill.exec.physical.impl.SingleSenderCreator;
import org.apache.drill.exec.physical.impl.mergereceiver.MergingRecordBatch;
import org.apache.drill.exec.physical.impl.partitionsender.PartitionerDecorator;
import org.apache.drill.exec.physical.impl.unorderedreceiver.UnorderedReceiverBatch;
import org.apache.drill.exec.physical.impl.xsort.ExternalSortBatch;
import org.apache.drill.exec.planner.physical.PlannerSettings;
import org.apache.drill.exec.planner.sql.DrillSqlWorker;
import org.apache.drill.exec.proto.GeneralRPCProtos;
import org.apache.drill.exec.proto.UserBitShared;
import org.apache.drill.exec.record.BatchSchema;
import org.apache.drill.exec.record.RecordBatchLoader;
import org.apache.drill.exec.record.VectorWrapper;
import org.apache.drill.exec.rpc.ConnectionThrottle;
import org.apache.drill.exec.rpc.DrillRpcFuture;
import org.apache.drill.exec.rpc.RpcException;
import org.apache.drill.exec.rpc.user.QueryDataBatch;
import org.apache.drill.exec.rpc.user.UserResultsListener;
import org.apache.drill.exec.store.pojo.PojoRecordReader;
import org.apache.drill.exec.testing.Controls;
import org.apache.drill.exec.testing.ControlsInjectionUtil;
import org.apache.drill.exec.testing.ExecutionControlsInjector;
import org.apache.drill.exec.util.Pointer;
import org.apache.drill.exec.work.WorkManager;
import org.apache.drill.exec.work.foreman.Foreman;
import org.apache.drill.exec.work.foreman.ForemanException;
import org.apache.drill.exec.work.foreman.ForemanSetupException;
import org.apache.drill.exec.work.foreman.FragmentsRunner;
import org.apache.drill.exec.work.foreman.QueryStateProcessor;
import org.apache.drill.exec.work.fragment.FragmentExecutor;
import org.apache.drill.test.BaseTestQuery;
import org.apache.drill.test.ClusterFixture;
import org.apache.drill.test.ClusterTest;
import org.apache.drill.test.LogFixture;
import org.apache.drill.test.QueryTestUtil;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Tags;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.TestInstantiationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Tags({@Tag("slow-test"), @Tag("flaky-test")})
/* loaded from: input_file:org/apache/drill/exec/server/TestDrillbitResilience.class */
public class TestDrillbitResilience extends ClusterTest {
    /** Logger-level overrides; built in startSomeDrillbits() and closed in tearDownAfterClass(). */
    protected static LogFixture logFixture;
    /** Presumably the repetition count for repeated tests (the @RepeatedTest annotations here use 3). */
    private static final int NUM_RUNS = 3;
    /** Presumably the repetition count for tests known to be problematic; currently also 3. */
    private static final int PROBLEMATIC_TEST_NUM_RUNS = 3;
    /** Per-test timeout; same value as the @Timeout annotations (JUnit's default unit is seconds). */
    private static final int TIMEOUT = 15;
    /** Default query executed by most tests. */
    private static final String TEST_QUERY = "select * from sys.memory";
    // Names of the three Drillbits started for the test cluster in startSomeDrillbits().
    private static final String DRILLBIT_ALPHA = "alpha";
    private static final String DRILLBIT_BETA = "beta";
    private static final String DRILLBIT_GAMMA = "gamma";
    private static final Logger logger = LoggerFactory.getLogger(TestDrillbitResilience.class);
    /** Level applied to this class's logger and to selected Drill loggers through logFixture. */
    private static final Level CURRENT_LOG_LEVEL = Level.INFO;

    /**
     * Thread that sleeps briefly and then asks the server to cancel a query.
     *
     * <p>An {@link RpcException} raised by the cancel request is stored in {@code ex}. The optional
     * latch is counted down in a {@code finally} block, so a {@link ResumingThread} waiting on the
     * same latch is always released, even if cancellation fails unexpectedly.
     */
    public static class CancellingThread extends Thread {
        private final UserBitShared.QueryId queryId;
        private final Pointer<Exception> ex;
        private final ExtendedLatch latch;

        /**
         * @param queryId id of the query to cancel
         * @param pointer receives the exception if the cancel request fails
         * @param extendedLatch latch counted down when this thread finishes; may be {@code null}
         */
        public CancellingThread(UserBitShared.QueryId queryId, Pointer<Exception> pointer, ExtendedLatch extendedLatch) {
            this.queryId = queryId;
            this.ex = pointer;
            this.latch = extendedLatch;
            TestDrillbitResilience.logger.debug("Cancelling thread created");
        }

        @Override
        public void run() {
            try {
                try {
                    // Give the query a moment to progress before cancelling it.
                    Thread.sleep(1000L);
                } catch (InterruptedException e) {
                    // Deliberately ignored: re-asserting the interrupt here would make the
                    // cancel request below fail instead of being sent.
                    TestDrillbitResilience.logger.debug("Cancelling thread interrupted. Ignore it");
                }
                TestDrillbitResilience.logger.debug("Cancelling {} query started", queryId);
                DrillRpcFuture<GeneralRPCProtos.Ack> cancelQuery = TestDrillbitResilience.client.client().cancelQuery(queryId);
                TestDrillbitResilience.logger.debug("Check future: {}", cancelQuery);
                try {
                    GeneralRPCProtos.Ack ack = cancelQuery.checkedGet();
                    TestDrillbitResilience.logger.debug("Sleep thread for 0.01 seconds");
                    Thread.sleep(10L);
                    TestDrillbitResilience.logger.debug("Ack: {}", ack);
                } catch (RpcException e) {
                    ex.value = e;
                    // Pass the exception as the throwable argument so its stack trace is logged.
                    TestDrillbitResilience.logger.debug("The query wasn't canceled.", e);
                } catch (InterruptedException e) {
                    TestDrillbitResilience.logger.debug("Sleep thread interrupted. Ignore it");
                }
            } finally {
                // Always release a waiting ResumingThread; otherwise it would block forever in
                // awaitUninterruptibly() when run() fails with an unchecked exception.
                if (latch != null) {
                    latch.countDown();
                }
                TestDrillbitResilience.logger.debug("Finish cancelling thread");
            }
        }
    }

    /**
     * Listener that starts a {@link CancellingThread} for its query when the first batch of data
     * arrives. Every received batch is released.
     */
    public static class ListenerThatCancelsQueryAfterFirstBatchOfData extends WaitUntilCompleteListener {
        /** Becomes {@code true} once the cancelling thread has been started. */
        private boolean cancelRequested = false;

        private ListenerThatCancelsQueryAfterFirstBatchOfData() {
        }

        @Override
        public void dataArrived(QueryDataBatch queryDataBatch, ConnectionThrottle connectionThrottle) {
            final boolean firstBatch = !cancelRequested;
            if (firstBatch) {
                TestDrillbitResilience.logger.debug("First batch arrived, so cancelling thread started");
                check(queryId != null, "Query id should not be null, since we have waited long enough.");
                final CancellingThread canceller = new CancellingThread(queryId, ex, null);
                canceller.start();
                cancelRequested = true;
            }
            queryDataBatch.release();
        }
    }

    /**
     * Thread that waits (uninterruptibly) on a latch and then asks the server to resume the query,
     * presumably releasing a pause injected through controls. An {@link RpcException} from the
     * resume request is stored in {@code ex}.
     */
    public static class ResumingThread extends Thread {
        private final UserBitShared.QueryId queryId;
        private final Pointer<Exception> ex;
        private final ExtendedLatch latch;

        /**
         * @param queryId id of the query to resume
         * @param pointer receives the exception if the resume request fails
         * @param extendedLatch latch this thread waits on before resuming; must not be {@code null}
         */
        public ResumingThread(UserBitShared.QueryId queryId, Pointer<Exception> pointer, ExtendedLatch extendedLatch) {
            this.queryId = queryId;
            this.ex = pointer;
            this.latch = extendedLatch;
            // Previously logged "Cancelling thread created" (copy-paste from CancellingThread).
            TestDrillbitResilience.logger.debug("Resuming thread created");
        }

        @Override
        public void run() {
            latch.awaitUninterruptibly();
            try {
                TestDrillbitResilience.client.client().resumeQuery(queryId).checkedGet();
            } catch (RpcException e) {
                ex.value = e;
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/drill/exec/server/TestDrillbitResilience$WaitUntilCompleteListener.class */
    public static class WaitUntilCompleteListener implements UserResultsListener {
        protected final ExtendedLatch latch;
        protected UserBitShared.QueryId queryId;
        protected volatile Pointer<Exception> ex;
        protected volatile UserBitShared.QueryResult.QueryState state;
        private int count;

        private WaitUntilCompleteListener() {
            this.latch = new ExtendedLatch(1);
            this.queryId = null;
            this.ex = new Pointer<>();
            this.state = null;
            this.count = 0;
        }

        protected final void check(boolean z, String str, Object... objArr) {
            if (z) {
                return;
            }
            this.ex.value = new IllegalStateException(String.format(str, objArr));
        }

        protected final void cancelAndResume() {
            Preconditions.checkNotNull(this.queryId);
            ExtendedLatch extendedLatch = new ExtendedLatch(1);
            new CancellingThread(this.queryId, this.ex, extendedLatch).start();
            new ResumingThread(this.queryId, this.ex, extendedLatch).start();
        }

        public void queryIdArrived(UserBitShared.QueryId queryId) {
            this.queryId = queryId;
        }

        public void submissionFailed(UserException userException) {
            this.ex.value = userException;
            this.state = UserBitShared.QueryResult.QueryState.FAILED;
            this.latch.countDown();
        }

        public void queryCompleted(UserBitShared.QueryResult.QueryState queryState) {
            this.state = queryState;
            this.latch.countDown();
        }

        public void dataArrived(QueryDataBatch queryDataBatch, ConnectionThrottle connectionThrottle) {
            queryDataBatch.release();
        }

        public final Pair<UserBitShared.QueryResult.QueryState, Exception> waitForCompletion() {
            try {
                TestDrillbitResilience.logger.debug("Wait for completion. latch: {}", Long.valueOf(this.latch.getCount()));
                this.latch.await();
            } catch (InterruptedException e) {
                TestDrillbitResilience.logger.error("Interrupted while waiting for event latch");
            }
            TestDrillbitResilience.logger.debug("Completed. Wait finished");
            return new Pair<>(this.state, (Exception) this.ex.value);
        }

        boolean lastDrillbit() {
            int i = this.count + 1;
            this.count = i;
            return i == TestDrillbitResilience.cluster.drillbits().size();
        }
    }

    /**
     * Installs logger-level overrides, then starts a three-Drillbit cluster (alpha, beta, gamma) with
     * the HTTP server disabled and clears any leftover control injections.
     */
    @BeforeAll
    public static void startSomeDrillbits() throws Exception {
        // QueryStateProcessor used to be registered twice. That is redundant, and depending on how
        // LogFixture records previous levels it may restore the wrong level on close.
        logFixture = LogFixture.builder()
            .toConsole()
            .logger(TestDrillbitResilience.class, CURRENT_LOG_LEVEL)
            .logger(DrillClient.class, CURRENT_LOG_LEVEL)
            .logger(QueryStateProcessor.class, CURRENT_LOG_LEVEL)
            .logger(WorkManager.class, CURRENT_LOG_LEVEL)
            .logger(UnorderedReceiverBatch.class, CURRENT_LOG_LEVEL)
            .logger(ExtendedLatch.class, CURRENT_LOG_LEVEL)
            .logger(Foreman.class, CURRENT_LOG_LEVEL)
            .logger(ExecutionControlsInjector.class, CURRENT_LOG_LEVEL)
            .build();
        ZookeeperTestUtil.setJaasTestConfigFile();
        dirTestWatcher.start(TestDrillbitResilience.class);
        startCluster(ClusterFixture.builder(dirTestWatcher)
            .configProperty("drill.exec.http.enabled", false)
            .withBits(DRILLBIT_ALPHA, DRILLBIT_BETA, DRILLBIT_GAMMA));
        clearAllInjections();
        logger.debug("Start 3 drillbits Test Drill cluster: {}, {}, {}", DRILLBIT_ALPHA, DRILLBIT_BETA, DRILLBIT_GAMMA);
    }

    /**
     * Restores the logger levels changed in startSomeDrillbits(). Tolerates a setup failure that
     * happened before the log fixture was created.
     */
    @AfterAll
    public static void tearDownAfterClass() throws Exception {
        if (logFixture != null) {
            logFixture.close();
        }
    }

    /**
     * Removes every control injection currently set on the cluster through the shared client.
     *
     * @throws NullPointerException if the client fixture has not been created
     */
    private static void clearAllInjections() {
        logger.debug("Clear all injections");
        ControlsInjectionUtil.clearControls(Preconditions.checkNotNull(client).client());
    }

    /**
     * Verifies that the cluster is healthy by running {@code select count(*) from sys.memory}.
     * The query must complete without errors and return a single BIGINT value equal to the number
     * of Drillbits in the cluster.
     *
     * @throws RuntimeException wrapping any exception raised while running the query
     */
    private static void assertDrillbitsOk() {
        SingleRowListener listener = new SingleRowListener() {
            private final BufferAllocator bufferAllocator = RootAllocatorFactory.newRoot(TestDrillbitResilience.cluster.config());
            private final RecordBatchLoader loader = new RecordBatchLoader(bufferAllocator);

            @Override
            public void rowArrived(QueryDataBatch queryDataBatch) {
                loader.load(queryDataBatch.getHeader().getDef(), queryDataBatch.getData());
                Assertions.assertEquals(1, loader.getRecordCount());
                BatchSchema schema = loader.getSchema();
                Assertions.assertEquals(1, schema.getFieldCount());
                Assertions.assertEquals(TypeProtos.MinorType.BIGINT, schema.getColumn(0).getType().getMinorType());
                VectorWrapper<?> wrapper = loader.iterator().next();
                Object value = wrapper.getValueVector().getAccessor().getObject(0);
                Assertions.assertTrue(value instanceof Long);
                Assertions.assertEquals(TestDrillbitResilience.cluster.drillbits().size(), ((Long) value).intValue());
                loader.clear();
            }

            @Override
            public void cleanup() {
                loader.clear();
                DrillAutoCloseables.closeNoChecked(bufferAllocator);
            }
        };
        try {
            QueryTestUtil.testWithListener(client.client(), UserBitShared.QueryType.SQL, "select count(*) from sys.memory", listener);
            listener.waitForCompletion();
            UserBitShared.QueryResult.QueryState queryState = listener.getQueryState();
            // Expected value first, as JUnit's assertSame contract requires.
            Assertions.assertSame(UserBitShared.QueryResult.QueryState.COMPLETED, queryState,
                () -> String.format("QueryState should be COMPLETED (and not %s).", queryState));
            Assertions.assertTrue(listener.getErrorList().isEmpty(), "There should not be any errors when checking if Drillbits are OK");
            logger.debug("Drillbits are ok.");
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Do not swallow the interrupt while wrapping it.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException("Couldn't query active drillbits", e);
        } finally {
            // Exactly one cleanup on every path. The decompiled form could run it twice.
            logger.debug("Cleanup listener");
            listener.cleanup();
        }
    }

    /**
     * Logs the start of each test, or of each repetition of a repeated test.
     *
     * @param testInfo JUnit metadata for the test about to run
     * @throws TestInstantiationException if JUnit provides no test method
     */
    @BeforeEach
    void setUp(TestInfo testInfo) {
        String name = testInfo.getTestMethod()
            .map(Method::getName)
            .orElseThrow(() -> new TestInstantiationException("Can't get method name"));
        String displayName = testInfo.getDisplayName();
        // @RepeatedTest display names start with "repetition N of M" by default.
        if (displayName.startsWith("repetition")) {
            logger.debug("{} for {} test started", displayName, name);
        } else {
            logger.debug("{} test started", name);
        }
    }

    /**
     * After each test, clears all control injections, verifies that every Drillbit is still
     * healthy, and logs the end of the test (or repetition).
     *
     * @param testInfo JUnit metadata for the test that just ran
     * @throws TestInstantiationException if JUnit provides no test method
     */
    @AfterEach
    public void checkDrillbits(TestInfo testInfo) {
        clearAllInjections();
        assertDrillbitsOk();
        String name = testInfo.getTestMethod()
            .map(Method::getName)
            .orElseThrow(() -> new TestInstantiationException("Can't get method name"));
        String displayName = testInfo.getDisplayName();
        // @RepeatedTest display names start with "repetition N of M" by default.
        if (displayName.startsWith("repetition")) {
            logger.debug("{} for {} test finished", displayName, name);
        } else {
            logger.debug("{} test finished", name);
        }
    }

    /**
     * An exception injection on a site that is never hit ("noop") must not disturb a normal query,
     * and must not leak memory.
     */
    @Test
    @Timeout(TIMEOUT)
    public void settingNoOpInjectionsAndQuery() {
        final long memoryBefore = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addExceptionOnBit(getClass(), "noop", RuntimeException.class,
                cluster.drillbit(DRILLBIT_BETA).getContext().getEndpoint())
            .build();
        setControls(controls);
        final WaitUntilCompleteListener listener = new WaitUntilCompleteListener();
        QueryTestUtil.testWithListener(client.client(), UserBitShared.QueryType.SQL, TEST_QUERY, listener);
        assertStateCompleted(listener.waitForCompletion(), UserBitShared.QueryResult.QueryState.COMPLETED);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /** Failure injected at the "run-try-beginning" site of the Foreman; memory must not leak. */
    @RepeatedTest(NUM_RUNS)
    @Timeout(TIMEOUT)
    public void foreman_runTryBeginning() {
        final long memoryBefore = countAllocatedMemory();
        testForeman("run-try-beginning");
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /** Failure injected at the "run-try-end" site of the Foreman; memory must not leak. */
    @RepeatedTest(NUM_RUNS)
    @Timeout(TIMEOUT)
    public void foreman_runTryEnd() {
        final long memoryBefore = countAllocatedMemory();
        testForeman("run-try-end");
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /**
     * Pauses the PojoRecordReader at "read-next" and resumes the query as soon as its id arrives.
     * The query must still complete, without leaking memory.
     */
    @Test
    @Timeout(TIMEOUT)
    public void passThrough() {
        final long memoryBefore = countAllocatedMemory();
        setControls(Controls.newBuilder().addPause(PojoRecordReader.class, "read-next").build());
        final WaitUntilCompleteListener listener = new WaitUntilCompleteListener() {
            @Override
            public void queryIdArrived(UserBitShared.QueryId queryId) {
                super.queryIdArrived(queryId);
                final ExtendedLatch resumeGate = new ExtendedLatch(1);
                new ResumingThread(queryId, ex, resumeGate).start();
                // Open the gate immediately so the resume request is sent right away.
                resumeGate.countDown();
            }
        };
        QueryTestUtil.testWithListener(client.client(), UserBitShared.QueryType.SQL, TEST_QUERY, listener);
        assertStateCompleted(listener.waitForCompletion(), UserBitShared.QueryResult.QueryState.COMPLETED);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /**
     * Cancels (and resumes) the query as soon as its id arrives, while fragments are paused at
     * "fragment-running". The query must be cancelled cleanly, without leaking memory.
     */
    @RepeatedTest(NUM_RUNS)
    @Timeout(TIMEOUT)
    public void cancelWhenQueryIdArrives() {
        final long memoryBefore = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addPause(FragmentExecutor.class, "fragment-running")
            .build();
        final WaitUntilCompleteListener listener = new WaitUntilCompleteListener() {
            @Override
            public void queryIdArrived(UserBitShared.QueryId queryId) {
                super.queryIdArrived(queryId);
                cancelAndResume();
            }
        };
        assertCancelledWithoutException(controls, listener);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /**
     * Cancels (and resumes) the query when the first batch arrives, while the screen operator is
     * paused once at "sending-data". The query must be cancelled cleanly, without leaking memory.
     */
    @RepeatedTest(NUM_RUNS)
    @Timeout(TIMEOUT)
    public void cancelInMiddleOfFetchingResults() {
        final long memoryBefore = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addPause(ScreenCreator.class, "sending-data", 1)
            .build();
        final WaitUntilCompleteListener listener = new WaitUntilCompleteListener() {
            private boolean cancelRequested = false;

            @Override
            public void dataArrived(QueryDataBatch queryDataBatch, ConnectionThrottle connectionThrottle) {
                if (cancelRequested == false) {
                    check(queryId != null, "Query id should not be null, since we have waited long enough.");
                    cancelAndResume();
                    cancelRequested = true;
                }
                queryDataBatch.release();
            }
        };
        assertCancelledWithoutException(controls, listener);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /**
     * Cancels (and resumes) the query once lastDrillbit() reports that the Drillbit count has been
     * reached, while the screen operator is paused at "send-complete". The query must be
     * cancelled cleanly, without leaking memory.
     */
    @RepeatedTest(NUM_RUNS)
    @Timeout(TIMEOUT)
    public void cancelAfterAllResultsProduced() {
        final long memoryBefore = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addPause(ScreenCreator.class, "send-complete")
            .build();
        final WaitUntilCompleteListener listener = new WaitUntilCompleteListener() {
            @Override
            public void dataArrived(QueryDataBatch queryDataBatch, ConnectionThrottle connectionThrottle) {
                final boolean allArrived = lastDrillbit();
                if (allArrived) {
                    check(queryId != null, "Query id should not be null, since we have waited long enough.");
                    cancelAndResume();
                }
                queryDataBatch.release();
            }
        };
        assertCancelledWithoutException(controls, listener);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /**
     * Requests cancellation after lastDrillbit() is reached, while the Foreman is paused at
     * "foreman-cleanup". The query is expected to complete anyway, without leaking memory.
     */
    @RepeatedTest(NUM_RUNS)
    @Timeout(TIMEOUT)
    public void cancelAfterEverythingIsCompleted() {
        final long memoryBefore = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addPause(Foreman.class, "foreman-cleanup")
            .build();
        final WaitUntilCompleteListener listener = new WaitUntilCompleteListener() {
            @Override
            public void dataArrived(QueryDataBatch queryDataBatch, ConnectionThrottle connectionThrottle) {
                final boolean allArrived = lastDrillbit();
                if (allArrived) {
                    check(queryId != null, "Query id should not be null, since we have waited long enough.");
                    cancelAndResume();
                }
                queryDataBatch.release();
            }
        };
        assertCompletedWithoutException(controls, listener);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /** A query with no injections must complete, without leaking memory. */
    @Test
    @Timeout(TIMEOUT)
    public void successfullyCompletes() {
        final long memoryBefore = countAllocatedMemory();
        final WaitUntilCompleteListener listener = new WaitUntilCompleteListener();
        QueryTestUtil.testWithListener(client.client(), UserBitShared.QueryType.SQL, TEST_QUERY, listener);
        final Pair<UserBitShared.QueryResult.QueryState, Exception> outcome = listener.waitForCompletion();
        assertStateCompleted(outcome, UserBitShared.QueryResult.QueryState.COMPLETED);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /** A ForemanSetupException injected at "sql-parsing" must fail the query, without leaking memory. */
    @Test
    @Timeout(TIMEOUT)
    public void failsWhenParsing() {
        final long memoryBefore = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addException(DrillSqlWorker.class, "sql-parsing", ForemanSetupException.class, 0, 2)
            .build();
        assertFailsWithException(controls, ForemanSetupException.class, "sql-parsing", TEST_QUERY);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /** A ForemanException injected at "send-fragments" must fail the query, without leaking memory. */
    @Test
    @Timeout(TIMEOUT)
    public void failsWhenSendingFragments() {
        final long memoryBefore = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addException(FragmentsRunner.class, "send-fragments", ForemanException.class)
            .build();
        assertFailsWithException(controls, ForemanException.class, "send-fragments", TEST_QUERY);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /** An IOException injected at "fragment-execution" must fail the query, without leaking memory. */
    @Test
    @Timeout(TIMEOUT)
    public void failsDuringExecution() {
        final long memoryBefore = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addException(FragmentExecutor.class, "fragment-execution", IOException.class)
            .build();
        assertFailsWithException(controls, IOException.class, "fragment-execution", TEST_QUERY);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /** Interrupts fragments blocked in MergingRecordBatch "waiting-for-data"; memory must not leak. */
    @RepeatedTest(NUM_RUNS)
    @Timeout(TIMEOUT)
    public void interruptingBlockedMergingRecordBatch() {
        final long memoryBefore = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addPause(MergingRecordBatch.class, "waiting-for-data", 1)
            .build();
        interruptingBlockedFragmentsWaitingForData(controls);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /** Interrupts fragments blocked in UnorderedReceiverBatch "waiting-for-data"; memory must not leak. */
    @RepeatedTest(NUM_RUNS)
    @Timeout(TIMEOUT)
    public void interruptingBlockedUnorderedReceiverBatch() {
        final long memoryBefore = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addPause(UnorderedReceiverBatch.class, "waiting-for-data", 1)
            .build();
        logger.debug("Start interruptingBlockedFragmentsWaitingForData");
        interruptingBlockedFragmentsWaitingForData(controls);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /**
     * Runs a hash-aggregation query with multi-threaded partition senders, pausing the partitioner
     * at "wait-for-fragment-interrupt", and cancels the query after the first batch. The query
     * must be cancelled cleanly, without leaking memory.
     *
     * <p>The altered session options are always reset in {@code finally}. The decompiled
     * catch-Throwable form duplicated the resets and re-ran them if a reset failed on the
     * success path.
     */
    @Timeout(TIMEOUT)
    @RepeatedTest(NUM_RUNS)
    public void interruptingPartitionerThreadFragment() {
        try {
            client.alterSession("planner.slice_target", "1");
            client.alterSession("planner.enable_hashagg", "true");
            client.alterSession(PlannerSettings.PARTITION_SENDER_SET_THREADS.getOptionName(), "6");
            final long memoryBefore = countAllocatedMemory();
            final String controls = Controls.newBuilder()
                .addLatch(PartitionerDecorator.class, "partitioner-sender-latch")
                .addPause(PartitionerDecorator.class, "wait-for-fragment-interrupt", 1)
                .build();
            assertCancelledWithoutException(controls, new ListenerThatCancelsQueryAfterFirstBatchOfData(),
                "SELECT sales_city, COUNT(*) cnt FROM cp.`region.json` GROUP BY sales_city");
            final long memoryAfter = countAllocatedMemory();
            Assertions.assertEquals(memoryBefore, memoryAfter,
                () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
        } finally {
            client.resetSession("planner.slice_target");
            client.resetSession("planner.enable_hashagg");
            client.resetSession(PlannerSettings.PARTITION_SENDER_SET_THREADS.getOptionName());
        }
    }

    /**
     * Cancels the query after the first batch while a single-sender fragment is paused at
     * "data-tunnel-send-batch-wait-for-interrupt". The query must be cancelled cleanly, without
     * leaking memory.
     */
    @RepeatedTest(NUM_RUNS)
    @Timeout(TIMEOUT)
    public void interruptingWhileFragmentIsBlockedInAcquiringSendingTicket() {
        final long memoryBefore = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addPause(SingleSenderCreator.SingleSenderRootExec.class, "data-tunnel-send-batch-wait-for-interrupt", 1)
            .build();
        final WaitUntilCompleteListener listener = new ListenerThatCancelsQueryAfterFirstBatchOfData();
        assertCancelledWithoutException(controls, listener);
        final long memoryAfter = countAllocatedMemory();
        Assertions.assertEquals(memoryBefore, memoryAfter,
            () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
    }

    /**
     * Runs TPC-H query 9 with planner.slice_target = 10 and cancels (and resumes) it when the first
     * batch arrives, while the screen operator is paused once at "sending-data". The query must
     * be cancelled cleanly, without leaking memory.
     *
     * <p>Everything after altering the session runs inside try/finally, so the option is reset
     * even if the initial memory measurement fails.
     */
    @Timeout(TIMEOUT)
    @RepeatedTest(NUM_RUNS)
    public void memoryLeaksWhenCancelled() {
        client.alterSession("planner.slice_target", "10");
        try {
            final long memoryBefore = countAllocatedMemory();
            final String controls = Controls.newBuilder().addPause(ScreenCreator.class, "sending-data", 1).build();
            String query = null;
            try {
                String file = BaseTestQuery.getFile("queries/tpch/09.sql");
                // Drop the last character, presumably the terminating ';' of the SQL file.
                query = file.substring(0, file.length() - 1);
            } catch (IOException e) {
                Assertions.fail("Failed to get query file", e);
            }
            assertCancelledWithoutException(controls, new WaitUntilCompleteListener() {
                private volatile boolean cancelRequested = false;

                @Override
                public void dataArrived(QueryDataBatch queryDataBatch, ConnectionThrottle connectionThrottle) {
                    if (!cancelRequested) {
                        check(queryId != null, "Query id should not be null, since we have waited long enough.");
                        cancelAndResume();
                        cancelRequested = true;
                    }
                    queryDataBatch.release();
                }
            }, query);
            final long memoryAfter = countAllocatedMemory();
            Assertions.assertEquals(memoryBefore, memoryAfter,
                () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
        } finally {
            client.resetSession("planner.slice_target");
        }
    }

    /**
     * Runs TPC-H query 9 with planner.slice_target = 10 and an IOException injected at
     * "fragment-execution". The query must fail, without leaking memory.
     *
     * <p>The session option is always reset in {@code finally}. A failure to load the query file
     * now keeps the IOException as the assertion's cause instead of flattening it into the message.
     */
    @Timeout(TIMEOUT)
    @RepeatedTest(NUM_RUNS)
    public void memoryLeaksWhenFailed() {
        client.alterSession("planner.slice_target", "10");
        try {
            final long memoryBefore = countAllocatedMemory();
            final String controls = Controls.newBuilder()
                .addException(FragmentExecutor.class, "fragment-execution", IOException.class)
                .build();
            String query = null;
            try {
                String file = BaseTestQuery.getFile("queries/tpch/09.sql");
                // Drop the last character, presumably the terminating ';' of the SQL file.
                query = file.substring(0, file.length() - 1);
            } catch (IOException e) {
                Assertions.fail("Failed to get query file", e);
            }
            assertFailsWithException(controls, IOException.class, "fragment-execution", query);
            final long memoryAfter = countAllocatedMemory();
            Assertions.assertEquals(memoryBefore, memoryAfter,
                () -> String.format("We are leaking %d bytes", memoryAfter - memoryBefore));
        } finally {
            client.resetSession("planner.slice_target");
        }
    }

    /**
     * Injects a {@link RuntimeException} at the {@code ExternalSortBatch} "after-sort" site while
     * running an ORDER BY query. The test asserts that the query fails with that exception and that
     * the allocated memory is unchanged afterwards.
     */
    @Timeout(15)
    @Test
    public void failsAfterMSorterSorting() {
        final String site = "after-sort";
        final long before = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addException(ExternalSortBatch.class, site, RuntimeException.class)
            .build();
        assertFailsWithException(controls, RuntimeException.class, site,
            "select n_name from cp.`tpch/lineitem.parquet` order by n_name");
        final long after = countAllocatedMemory();
        Assertions.assertEquals(before, after, () -> String.format("We are leaking %d bytes", after - before));
    }

    /**
     * Injects a {@link RuntimeException} at the {@code ExternalSortBatch} "after-setup" site while
     * running an ORDER BY query. The test asserts that the query fails with that exception and that
     * the allocated memory is unchanged afterwards.
     */
    @Timeout(15)
    @Test
    public void failsAfterMSorterSetup() {
        final String site = "after-setup";
        final long before = countAllocatedMemory();
        final String controls = Controls.newBuilder()
            .addException(ExternalSortBatch.class, site, RuntimeException.class)
            .build();
        assertFailsWithException(controls, RuntimeException.class, site,
            "select n_name from cp.`tpch/lineitem.parquet` order by n_name");
        final long after = countAllocatedMemory();
        Assertions.assertEquals(before, after, () -> String.format("We are leaking %d bytes", after - before));
    }

    /**
     * Fails the test unless the query outcome has exactly the expected final state and no
     * exception.
     *
     * @param pair       (final state, exception or null) as reported by the listener
     * @param queryState the state the query is expected to have ended in
     */
    private void assertStateCompleted(Pair<UserBitShared.QueryResult.QueryState, Exception> pair, UserBitShared.QueryResult.QueryState queryState) {
        UserBitShared.QueryResult.QueryState actualState = pair.getFirst();
        Exception thrown = pair.getSecond();
        boolean asExpected = actualState == queryState && thrown == null;
        if (!asExpected) {
            Object shownException = (thrown != null) ? thrown : "none.";
            Assertions.fail(String.format("Query state is incorrect (expected: %s, actual: %s) AND/OR \nException thrown: %s",
                queryState, actualState, shownException));
        }
    }

    /**
     * Applies the injection controls, runs {@code query} with the given listener, and asserts that
     * the query ended CANCELED without reporting an exception.
     *
     * @param controls injection controls string passed to {@link #setControls}
     * @param listener listener that receives results and reports the final state
     * @param query    SQL text to execute
     */
    private void assertCancelledWithoutException(String controls, WaitUntilCompleteListener listener, String query) {
        setControls(controls);
        QueryTestUtil.testWithListener(client.client(), UserBitShared.QueryType.SQL, query, listener);
        Pair<UserBitShared.QueryResult.QueryState, Exception> outcome = listener.waitForCompletion();
        assertStateCompleted(outcome, UserBitShared.QueryResult.QueryState.CANCELED);
    }

    /**
     * Applies the injection controls, runs {@code TEST_QUERY} with the given listener, and asserts
     * that the query ended COMPLETED without reporting an exception.
     *
     * @param controls injection controls string passed to {@link #setControls}
     * @param listener listener that receives results and reports the final state
     */
    private void assertCompletedWithoutException(String controls, WaitUntilCompleteListener listener) {
        setControls(controls);
        QueryTestUtil.testWithListener(client.client(), UserBitShared.QueryType.SQL, TEST_QUERY, listener);
        Pair<UserBitShared.QueryResult.QueryState, Exception> outcome = listener.waitForCompletion();
        assertStateCompleted(outcome, UserBitShared.QueryResult.QueryState.COMPLETED);
    }

    /**
     * Convenience overload that runs {@code TEST_QUERY} and asserts it ends CANCELED without an
     * exception.
     *
     * @param controls injection controls string
     * @param listener listener that receives results and reports the final state
     */
    private void assertCancelledWithoutException(String controls, WaitUntilCompleteListener listener) {
        final String query = TEST_QUERY;
        assertCancelledWithoutException(controls, listener, query);
    }

    /**
     * Returns the sum of {@code getAllocatedMemory()} over the allocators of all Drillbits in the
     * cluster.
     * <p>
     * It first sleeps briefly, presumably so that fragments still shutting down can release their
     * buffers before the measurement. If the sleep is interrupted, the interrupt flag is restored
     * and the measurement is taken immediately.
     *
     * @return total allocated bytes across all Drillbits
     */
    private long countAllocatedMemory() {
        final long settleMillis = 1500L;
        try {
            // The log message is derived from the actual sleep duration so the two cannot drift.
            logger.debug("Sleep thread for " + settleMillis + " ms");
            Thread.sleep(settleMillis);
        } catch (InterruptedException e) {
            // Restore the flag so whoever interrupted this thread can still observe it.
            Thread.currentThread().interrupt();
            logger.debug("Sleep thread interrupted; measuring allocated memory without waiting", e);
        }
        long total = 0;
        for (Drillbit drillbit : cluster.drillbits()) {
            total += drillbit.getContext().getAllocator().getAllocatedMemory();
        }
        logger.debug("Allocated memory: " + total);
        return total;
    }

    /**
     * Sets {@code planner.slice_target=1} and {@code planner.enable_hashagg=false} and applies the
     * given injection controls. It then runs a GROUP BY over {@code cp.`region.json`} with a
     * {@code ListenerThatCancelsQueryAfterFirstBatchOfData} and asserts that the query ends CANCELED
     * without an exception.
     * <p>
     * Both session options are always reset afterwards. The nested {@code finally} ensures
     * {@code planner.enable_hashagg} is reset even if resetting {@code planner.slice_target}
     * throws.
     *
     * @param controls injection controls string
     */
    private void interruptingBlockedFragmentsWaitingForData(String controls) {
        try {
            client.alterSession("planner.slice_target", "1");
            client.alterSession("planner.enable_hashagg", "false");
            assertCancelledWithoutException(controls, new ListenerThatCancelsQueryAfterFirstBatchOfData(),
                "SELECT sales_city, COUNT(*) cnt FROM cp.`region.json` GROUP BY sales_city");
        } finally {
            try {
                client.resetSession("planner.slice_target");
            } finally {
                client.resetSession("planner.enable_hashagg");
            }
        }
    }

    /**
     * Applies the injection controls, runs {@code query}, and asserts that the query ended FAILED
     * with a {@link UserException} carrying the expected exception class and message.
     *
     * @param controls     injection controls string
     * @param expectedType class of the injected exception expected in the error
     * @param expectedDesc expected exception message (the injection site string)
     * @param query        SQL text to execute
     */
    private void assertFailsWithException(String controls, Class<? extends Throwable> expectedType, String expectedDesc, String query) {
        setControls(controls);
        WaitUntilCompleteListener listener = new WaitUntilCompleteListener();
        QueryTestUtil.testWithListener(client.client(), UserBitShared.QueryType.SQL, query, listener);
        Pair<UserBitShared.QueryResult.QueryState, Exception> outcome = listener.waitForCompletion();
        UserBitShared.QueryResult.QueryState state = outcome.getFirst();
        // JUnit's signature is assertSame(expected, actual); keep that order so the generated
        // "expected: <..> but was: <..>" text is correct.
        Assertions.assertSame(UserBitShared.QueryResult.QueryState.FAILED, state,
            () -> String.format("Query state should be FAILED (and not %s).", state));
        assertExceptionMessage(outcome.getSecond(), expectedType, expectedDesc);
    }

    /**
     * Applies the injection controls, runs {@code query}, and asserts that the query ended in
     * FAILED or COMPLETED state. It then checks the reported exception.
     * <p>
     * The exception check runs regardless of state. A COMPLETED query must therefore still have
     * reported a {@link UserException} with the expected class and message.
     *
     * @param controls     injection controls string
     * @param expectedType class of the injected exception expected in the error
     * @param expectedDesc expected exception message (the injection site string)
     * @param query        SQL text to execute
     */
    private void assertFailsOrCompletedWithException(String controls, Class<? extends Throwable> expectedType, String expectedDesc, String query) {
        setControls(controls);
        WaitUntilCompleteListener listener = new WaitUntilCompleteListener();
        QueryTestUtil.testWithListener(client.client(), UserBitShared.QueryType.SQL, query, listener);
        Pair<UserBitShared.QueryResult.QueryState, Exception> outcome = listener.waitForCompletion();
        UserBitShared.QueryResult.QueryState state = outcome.getFirst();
        // Enum identity comparison is null-safe: a missing state fails the assertion instead of NPE-ing.
        boolean acceptable = state == UserBitShared.QueryResult.QueryState.FAILED
            || state == UserBitShared.QueryResult.QueryState.COMPLETED;
        Assertions.assertTrue(acceptable,
            () -> String.format("Query state should be FAILED or COMPLETED (and not %s).", state));
        assertExceptionMessage(outcome.getSecond(), expectedType, expectedDesc);
    }

    /**
     * Injects a {@code ForemanException} at the given {@code Foreman} site, runs {@code TEST_QUERY},
     * and asserts that the query ends FAILED or COMPLETED with that exception reported.
     *
     * @param site injection site name in {@code Foreman}
     */
    private void testForeman(String site) {
        String controls = Controls.newBuilder()
            .addException(Foreman.class, site, ForemanException.class)
            .build();
        assertFailsOrCompletedWithException(controls, ForemanException.class, site, TEST_QUERY);
    }

    /**
     * Installs the given injection controls on the test client's session through
     * {@code ControlsInjectionUtil}.
     *
     * @param controls injection controls string built with {@code Controls.newBuilder()}
     */
    private void setControls(String controls) {
        ControlsInjectionUtil.setControls(client.client(), controls);
    }

    /**
     * Asserts that {@code thrown} is a {@link UserException} whose protobuf error carries the
     * expected exception class name and message.
     *
     * @param thrown       exception reported by the query (may be null)
     * @param expectedType class whose fully qualified name must match the reported exception class
     * @param expectedDesc expected exception message (the injection site string)
     */
    private void assertExceptionMessage(Throwable thrown, Class<? extends Throwable> expectedType, String expectedDesc) {
        // Include the actual throwable (possibly null) so a failure shows what was received.
        Assertions.assertTrue(thrown instanceof UserException,
            () -> "Throwable was not of UserException type: " + thrown);
        UserBitShared.ExceptionWrapper wrapper = ((UserException) thrown).getOrCreatePBError(false).getException();
        Assertions.assertEquals(expectedType.getName(), wrapper.getExceptionClass(), "Exception class names should match");
        Assertions.assertEquals(expectedDesc, wrapper.getMessage(), "Exception sites should match.");
    }
}
