package org.apache.spark.network.shuffle;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.crypto.AuthClientBootstrap;
import org.apache.spark.network.crypto.AuthServerBootstrap;
import org.apache.spark.network.sasl.SaslClientBootstrap;
import org.apache.spark.network.sasl.SaslServerBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/spark/network/shuffle/AppIsolationSuite.class */
public class AppIsolationSuite {
    private static final long TIMEOUT_MS = 10000;
    private static SecretKeyHolder secretKeyHolder;
    private static TransportConf conf;

    @BeforeClass
    public static void beforeAll() {
        HashMap hashMap = new HashMap();
        hashMap.put("spark.network.crypto.enabled", "true");
        hashMap.put("spark.network.crypto.saslFallback", "false");
        conf = new TransportConf("shuffle", new MapConfigProvider(hashMap));
        secretKeyHolder = (SecretKeyHolder) Mockito.mock(SecretKeyHolder.class);
        Mockito.when(secretKeyHolder.getSaslUser((String) Mockito.eq("app-1"))).thenReturn("app-1");
        Mockito.when(secretKeyHolder.getSecretKey((String) Mockito.eq("app-1"))).thenReturn("app-1");
        Mockito.when(secretKeyHolder.getSaslUser((String) Mockito.eq("app-2"))).thenReturn("app-2");
        Mockito.when(secretKeyHolder.getSecretKey((String) Mockito.eq("app-2"))).thenReturn("app-2");
    }

    @Test
    public void testSaslAppIsolation() throws Exception {
        testAppIsolation(() -> {
            return new SaslServerBootstrap(conf, secretKeyHolder);
        }, str -> {
            return new SaslClientBootstrap(conf, str, secretKeyHolder);
        });
    }

    @Test
    public void testAuthEngineAppIsolation() throws Exception {
        testAppIsolation(() -> {
            return new AuthServerBootstrap(conf, secretKeyHolder);
        }, str -> {
            return new AuthClientBootstrap(conf, str, secretKeyHolder);
        });
    }

    private void testAppIsolation(Supplier<TransportServerBootstrap> supplier, Function<String, TransportClientBootstrap> function) throws Exception {
        ExternalShuffleBlockHandler externalShuffleBlockHandler = new ExternalShuffleBlockHandler(new OneForOneStreamManager(), (ExternalShuffleBlockResolver) Mockito.mock(ExternalShuffleBlockResolver.class));
        TransportServerBootstrap transportServerBootstrap = supplier.get();
        TransportContext transportContext = new TransportContext(conf, externalShuffleBlockHandler);
        TransportServer createServer = transportContext.createServer(Arrays.asList(transportServerBootstrap));
        try {
            TransportClientFactory createClientFactory = transportContext.createClientFactory(Arrays.asList(function.apply("app-1")));
            try {
                TransportClient createClient = createClientFactory.createClient(TestUtils.getLocalHost(), createServer.getPort());
                try {
                    final AtomicReference atomicReference = new AtomicReference();
                    final CountDownLatch countDownLatch = new CountDownLatch(1);
                    BlockFetchingListener blockFetchingListener = new BlockFetchingListener() { // from class: org.apache.spark.network.shuffle.AppIsolationSuite.1
                        public void onBlockFetchSuccess(String str, ManagedBuffer managedBuffer) {
                            countDownLatch.countDown();
                        }

                        public void onBlockFetchFailure(String str, Throwable th) {
                            atomicReference.set(th);
                            countDownLatch.countDown();
                        }
                    };
                    String[] strArr = {"shuffle_0_1_2", "shuffle_0_3_4"};
                    new OneForOneBlockFetcher(createClient, "app-2", "0", strArr, blockFetchingListener, conf).start();
                    countDownLatch.await();
                    checkSecurityException((Throwable) atomicReference.get());
                    createClient.sendRpcSync(new RegisterExecutor("app-1", "0", new ExecutorShuffleInfo(new String[]{System.getProperty("java.io.tmpdir")}, 1, "org.apache.spark.shuffle.sort.SortShuffleManager")).toByteBuffer(), TIMEOUT_MS);
                    long j = BlockTransferMessage.Decoder.fromByteBuffer(createClient.sendRpcSync(new OpenBlocks("app-1", "0", strArr).toByteBuffer(), TIMEOUT_MS)).streamId;
                    TransportClientFactory createClientFactory2 = transportContext.createClientFactory(Arrays.asList(function.apply("app-2")));
                    try {
                        TransportClient createClient2 = createClientFactory2.createClient(TestUtils.getLocalHost(), createServer.getPort());
                        try {
                            final CountDownLatch countDownLatch2 = new CountDownLatch(1);
                            ChunkReceivedCallback chunkReceivedCallback = new ChunkReceivedCallback() { // from class: org.apache.spark.network.shuffle.AppIsolationSuite.2
                                public void onSuccess(int i, ManagedBuffer managedBuffer) {
                                    countDownLatch2.countDown();
                                }

                                public void onFailure(int i, Throwable th) {
                                    atomicReference.set(th);
                                    countDownLatch2.countDown();
                                }
                            };
                            atomicReference.set(null);
                            createClient2.fetchChunk(j, 0, chunkReceivedCallback);
                            countDownLatch2.await();
                            checkSecurityException((Throwable) atomicReference.get());
                            if (createClient2 != null) {
                                createClient2.close();
                            }
                            if (createClientFactory2 != null) {
                                createClientFactory2.close();
                            }
                            if (createClient != null) {
                                createClient.close();
                            }
                            if (createClientFactory != null) {
                                createClientFactory.close();
                            }
                            if (createServer != null) {
                                createServer.close();
                            }
                        } catch (Throwable th) {
                            if (createClient2 != null) {
                                try {
                                    createClient2.close();
                                } catch (Throwable th2) {
                                    th.addSuppressed(th2);
                                }
                            }
                            throw th;
                        }
                    } catch (Throwable th3) {
                        if (createClientFactory2 != null) {
                            try {
                                createClientFactory2.close();
                            } catch (Throwable th4) {
                                th3.addSuppressed(th4);
                            }
                        }
                        throw th3;
                    }
                } catch (Throwable th5) {
                    if (createClient != null) {
                        try {
                            createClient.close();
                        } catch (Throwable th6) {
                            th5.addSuppressed(th6);
                        }
                    }
                    throw th5;
                }
            } catch (Throwable th7) {
                if (createClientFactory != null) {
                    try {
                        createClientFactory.close();
                    } catch (Throwable th8) {
                        th7.addSuppressed(th8);
                    }
                }
                throw th7;
            }
        } catch (Throwable th9) {
            if (createServer != null) {
                try {
                    createServer.close();
                } catch (Throwable th10) {
                    th9.addSuppressed(th10);
                }
            }
            throw th9;
        }
    }

    private static void checkSecurityException(Throwable th) {
        Assert.assertNotNull("No exception was caught.", th);
        Assert.assertTrue("Expected SecurityException.", th.getMessage().contains(SecurityException.class.getName()));
    }
}
