package org.apache.spark.network;

import com.google.common.collect.Sets;
import com.google.common.io.Files;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.network.StreamSuite;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;

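/**
 * Integration tests that send RPCs, one-way messages and stream uploads from a
 * {@link TransportClient} to a {@link TransportServer}, both created from one shared
 * {@link TransportContext}.
 */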
public class RpcIntegrationSuite {
    // Transport configuration created in setUp() and also passed to testData.openStream().
    static TransportConf conf;
    // Context built in setUp() from conf and rpcHandler; the server and client factory come from it.
    static TransportContext context;
    // Server created by context.createServer(); tests connect to server.getPort().
    static TransportServer server;
    // Factory created by context.createClientFactory(); every test obtains its client from it.
    static TransportClientFactory clientFactory;
    // Server-side handler installed in setUp(); dispatches on the text of the incoming message.
    static RpcHandler rpcHandler;
    // Payloads of one-way messages, appended by rpcHandler's two-argument receive().
    static List<String> oneWayMsgs;
    // Helper providing the source buffers, test file and temp directory used by the stream tests.
    static StreamTestHelper testData;
    // Verifying callbacks registered by receiveStreamHelper(), keyed by the stream message.
    // Nothing in this file clears the map, so sendRpcWithStream() re-verifies earlier entries too.
    static ConcurrentHashMap<String, VerifyingStreamCallback> streamCallbacks = new ConcurrentHashMap<>();

    /**
     * Mutable holder for the outcome of a batch of requests. The two sets are not initialised
     * here; {@code sendRPC} and {@code sendRpcWithStream} assign synchronized sets before the
     * response callbacks start adding to them.
     */
    public static class RpcResult {
        // Entries added by the onSuccess callbacks.
        public Set<String> successMessages;
        // Throwable messages added by the onFailure callbacks.
        public Set<String> errorMessages;

        RpcResult() {
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/spark/network/RpcIntegrationSuite$RpcStreamCallback.class */
    public static class RpcStreamCallback implements RpcResponseCallback {
        final String streamId;
        final RpcResult res;
        final Semaphore sem;

        RpcStreamCallback(String str, RpcResult rpcResult, Semaphore semaphore) {
            this.streamId = str;
            this.res = rpcResult;
            this.sem = semaphore;
        }

        public void onSuccess(ByteBuffer byteBuffer) {
            this.res.successMessages.add(this.streamId);
            this.sem.release();
        }

        public void onFailure(Throwable th) {
            this.res.errorMessages.add(th.getMessage());
            this.sem.release();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/spark/network/RpcIntegrationSuite$VerifyingStreamCallback.class */
    public static class VerifyingStreamCallback implements StreamCallbackWithID {
        final String streamId;
        final StreamSuite.TestCallback helper;
        final OutputStream out;
        final File outFile;

        VerifyingStreamCallback(String str) throws IOException {
            if (str.equals("file")) {
                this.outFile = File.createTempFile("data", ".tmp", RpcIntegrationSuite.testData.tempDir);
                this.out = new FileOutputStream(this.outFile);
            } else {
                this.out = new ByteArrayOutputStream();
                this.outFile = null;
            }
            this.streamId = str;
            this.helper = new StreamSuite.TestCallback(this.out);
        }

        void verify() throws IOException {
            ByteBuffer duplicate;
            if (this.streamId.equals("file")) {
                Assert.assertTrue("File stream did not match.", Files.equal(RpcIntegrationSuite.testData.testFile, this.outFile));
                return;
            }
            byte[] byteArray = ((ByteArrayOutputStream) this.out).toByteArray();
            ByteBuffer srcBuffer = RpcIntegrationSuite.testData.srcBuffer(this.streamId);
            synchronized (srcBuffer) {
                duplicate = srcBuffer.duplicate();
            }
            byte[] bArr = new byte[duplicate.remaining()];
            duplicate.get(bArr);
            Assert.assertEquals(bArr.length, byteArray.length);
            Assert.assertTrue("buffers don't match", Arrays.equals(bArr, byteArray));
        }

        public void onData(String str, ByteBuffer byteBuffer) throws IOException {
            this.helper.onData(str, byteBuffer);
        }

        public void onComplete(String str) throws IOException {
            this.helper.onComplete(str);
        }

        public void onFailure(String str, Throwable th) throws IOException {
            this.helper.onFailure(str, th);
        }

        public String getID() {
            return this.streamId;
        }
    }

    /**
     * Builds the shared fixtures: the configuration, the test data helper, an {@link RpcHandler}
     * that reacts to the text of each message, and the server and client factory derived from it.
     */
    @BeforeClass
    public static void setUp() throws Exception {
        conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
        testData = new StreamTestHelper();
        rpcHandler = new RpcHandler() {
            /** Messages look like {@code "<command>/<argument>"}; unknown commands get no reply. */
            @Override
            public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
                String[] parts = JavaUtils.bytesToString(message).split("/");
                switch (parts[0]) {
                    case "hello":
                        callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!"));
                        break;
                    case "return error":
                        callback.onFailure(new RuntimeException("Returned: " + parts[1]));
                        break;
                    case "throw error":
                        throw new RuntimeException("Thrown: " + parts[1]);
                    default:
                        break;
                }
            }

            @Override
            public StreamCallbackWithID receiveStream(TransportClient client, ByteBuffer messageHeader, RpcResponseCallback callback) {
                return RpcIntegrationSuite.receiveStreamHelper(JavaUtils.bytesToString(messageHeader));
            }

            /** One-way messages are only recorded. */
            @Override
            public void receive(TransportClient client, ByteBuffer message) {
                RpcIntegrationSuite.oneWayMsgs.add(JavaUtils.bytesToString(message));
            }

            @Override
            public StreamManager getStreamManager() {
                return new OneForOneStreamManager();
            }
        };
        context = new TransportContext(conf, rpcHandler);
        server = context.createServer();
        clientFactory = context.createClientFactory();
        oneWayMsgs = new ArrayList<>();
    }

    /**
     * Chooses the server-side stream callback for an uploaded stream, based on its message.
     *
     * <p>A message that does not start with {@code "fail/"} gets a {@link VerifyingStreamCallback},
     * which is also registered in {@code streamCallbacks} under the message. A message of the form
     * {@code "fail/<mode>/..."} selects a deliberately misbehaving handler:
     * <ul>
     *   <li>{@code exception-ondata}: {@code onData} throws an {@link IOException};</li>
     *   <li>{@code exception-oncomplete}: {@code onComplete} throws an {@link IOException};</li>
     *   <li>{@code null}: no handler at all ({@code null} is returned).</li>
     * </ul>
     *
     * @param str the stream message sent by the client
     * @return the callback to feed the stream into, or {@code null} for the {@code null} mode
     * @throws IllegalArgumentException if the failure mode is not one of the three above
     * @throws RuntimeException wrapping the {@link IOException} if the verifying callback cannot be created
     */
    public static StreamCallbackWithID receiveStreamHelper(final String str) {
        try {
            if (!str.startsWith("fail/")) {
                VerifyingStreamCallback callback = new VerifyingStreamCallback(str);
                streamCallbacks.put(str, callback);
                return callback;
            }
            String failureMode = str.split("/")[1];
            switch (failureMode) {
                case "exception-ondata":
                    return new StreamCallbackWithID() {
                        @Override
                        public void onData(String streamId, ByteBuffer buf) throws IOException {
                            throw new IOException("failed to read stream data!");
                        }

                        @Override
                        public void onComplete(String streamId) throws IOException {
                        }

                        @Override
                        public void onFailure(String streamId, Throwable cause) throws IOException {
                        }

                        @Override
                        public String getID() {
                            return str;
                        }
                    };
                case "exception-oncomplete":
                    return new StreamCallbackWithID() {
                        @Override
                        public void onData(String streamId, ByteBuffer buf) throws IOException {
                        }

                        @Override
                        public void onComplete(String streamId) throws IOException {
                            throw new IOException("exception in onComplete");
                        }

                        @Override
                        public void onFailure(String streamId, Throwable cause) throws IOException {
                        }

                        @Override
                        public String getID() {
                            return str;
                        }
                    };
                case "null":
                    return null;
                default:
                    throw new IllegalArgumentException("unexpected msg: " + str);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Releases the shared fixtures after the last test: closes the server, the client factory
     * and the transport context (in that order), then asks the test data helper to clean up.
     */
    @AfterClass
    public static void tearDown() {
        server.close();
        clientFactory.close();
        context.close();
        testData.cleanup();
    }

    /**
     * Sends each command as an RPC over a single client and waits for all replies.
     *
     * @param strArr the RPC payloads, e.g. {@code "hello/Aaron"}
     * @return the success payloads and failure messages collected from the replies
     * @throws Exception if the client cannot be created or the wait is interrupted; an
     *     {@link AssertionError} is raised if not every reply arrives within 5 seconds
     */
    private RpcResult sendRPC(String... strArr) throws Exception {
        TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
        // Close the client on every exit path, including the timeout assertion below.
        try {
            final Semaphore replies = new Semaphore(0);
            final RpcResult result = new RpcResult();
            // Callbacks add to these sets from another thread, hence the synchronized wrappers.
            result.successMessages = Collections.synchronizedSet(new HashSet<String>());
            result.errorMessages = Collections.synchronizedSet(new HashSet<String>());
            RpcResponseCallback callback = new RpcResponseCallback() {
                @Override
                public void onSuccess(ByteBuffer message) {
                    result.successMessages.add(JavaUtils.bytesToString(message));
                    replies.release();
                }

                @Override
                public void onFailure(Throwable cause) {
                    result.errorMessages.add(cause.getMessage());
                    replies.release();
                }
            };
            for (String command : strArr) {
                client.sendRpc(JavaUtils.stringToBytes(command), callback);
            }
            // One permit per reply: wait until every command has been answered.
            if (!replies.tryAcquire(strArr.length, 5L, TimeUnit.SECONDS)) {
                Assert.fail("Timeout getting response from the server");
            }
            return result;
        } finally {
            client.close();
        }
    }

    /**
     * Uploads one stream per entry over a single client, waits for all replies, then verifies
     * every callback currently registered in {@code streamCallbacks}.
     *
     * <p>Each entry is sent verbatim as the stream's message; the text after its last
     * {@code '/'} (or the whole entry if there is none) names the test stream to upload.
     *
     * @param strArr the stream messages, e.g. {@code "smallBuffer"} or {@code "fail/null/smallBuffer"}
     * @return the stream ids that succeeded and the failure messages that were reported
     * @throws Exception if the client cannot be created or the wait is interrupted; an
     *     {@link AssertionError} is raised on timeout or when a captured stream does not match
     */
    private RpcResult sendRpcWithStream(String... strArr) throws Exception {
        TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
        // Close the client on every exit path, including timeout and verification failures.
        try {
            Semaphore replies = new Semaphore(0);
            RpcResult result = new RpcResult();
            // Callbacks add to these sets from another thread, hence the synchronized wrappers.
            result.successMessages = Collections.synchronizedSet(new HashSet<String>());
            result.errorMessages = Collections.synchronizedSet(new HashSet<String>());
            for (String str : strArr) {
                int lastSlash = str.lastIndexOf('/');
                String streamName = lastSlash == -1 ? str : str.substring(lastSlash + 1);
                client.uploadStream(new NioManagedBuffer(JavaUtils.stringToBytes(str)), testData.openStream(conf, streamName), new RpcStreamCallback(str, result, replies));
            }
            // One permit per reply: wait until every upload has been answered.
            if (!replies.tryAcquire(strArr.length, 5L, TimeUnit.SECONDS)) {
                Assert.fail("Timeout getting response from the server");
            }
            streamCallbacks.values().forEach(callback -> {
                try {
                    callback.verify();
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            });
            return result;
        } finally {
            client.close();
        }
    }

    /** A single "hello" RPC yields exactly one greeting and no errors. */
    @Test
    public void singleRPC() throws Exception {
        RpcResult result = sendRPC("hello/Aaron");
        Assert.assertEquals(Sets.newHashSet("Hello, Aaron!"), result.successMessages);
        Assert.assertTrue(result.errorMessages.isEmpty());
    }

    /** Two "hello" RPCs on the same client yield both greetings and no errors. */
    @Test
    public void doubleRPC() throws Exception {
        RpcResult result = sendRPC("hello/Aaron", "hello/Reynold");
        Assert.assertEquals(Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!"), result.successMessages);
        Assert.assertTrue(result.errorMessages.isEmpty());
    }

    /** An error reported through the callback reaches the client as a failure message. */
    @Test
    public void returnErrorRPC() throws Exception {
        RpcResult result = sendRPC("return error/OK");
        Assert.assertTrue(result.successMessages.isEmpty());
        assertErrorsContain(result.errorMessages, Sets.newHashSet("Returned: OK"));
    }

    /** An exception thrown by the handler reaches the client as a failure message. */
    @Test
    public void throwErrorRPC() throws Exception {
        RpcResult result = sendRPC("throw error/uh-oh");
        Assert.assertTrue(result.successMessages.isEmpty());
        assertErrorsContain(result.errorMessages, Sets.newHashSet("Thrown: uh-oh"));
    }

    /** A returned error and a thrown error on the same client are both reported. */
    @Test
    public void doubleTrouble() throws Exception {
        RpcResult result = sendRPC("return error/OK", "throw error/uh-oh");
        Assert.assertTrue(result.successMessages.isEmpty());
        assertErrorsContain(result.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh"));
    }

    /** Successes and failures interleaved on one client are each reported on the right side. */
    @Test
    public void sendSuccessAndFailure() throws Exception {
        RpcResult result = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!");
        Assert.assertEquals(Sets.newHashSet("Hello, Bob!", "Hello, Builder!"), result.successMessages);
        assertErrorsContain(result.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !"));
    }

    /**
     * Sends a one-way message and waits (up to 10 seconds, polling every 10 ms) for the handler
     * to record it. A one-way send must not leave an outstanding request on the client.
     */
    @Test
    public void sendOneWayMessage() throws Exception {
        TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
        try {
            client.send(JavaUtils.stringToBytes("no reply"));
            Assert.assertEquals(0L, client.getHandler().numOutstandingRequests());
            long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10L);
            // nanoTime values may wrap, so compare by difference rather than with '<' directly.
            while (System.nanoTime() - deadline < 0 && oneWayMsgs.size() == 0) {
                TimeUnit.MILLISECONDS.sleep(10L);
            }
            Assert.assertEquals(1L, oneWayMsgs.size());
            Assert.assertEquals("no reply", oneWayMsgs.get(0));
        } finally {
            client.close();
        }
    }

    /** Uploads every test stream on its own and expects each to succeed without errors. */
    @Test
    public void sendRpcWithStreamOneAtATime() throws Exception {
        for (String stream : StreamTestHelper.STREAMS) {
            RpcResult result = sendRpcWithStream(stream);
            Assert.assertTrue("there were error messages!" + result.errorMessages, result.errorMessages.isEmpty());
            Assert.assertEquals(Sets.newHashSet(stream), result.successMessages);
        }
    }

    /**
     * Uploads ten streams over one client without waiting in between, cycling through the
     * available test streams, and expects every distinct stream to be reported as a success.
     */
    @Test
    public void sendRpcWithStreamConcurrently() throws Exception {
        String[] streams = new String[10];
        Arrays.setAll(streams, i -> StreamTestHelper.STREAMS[i % StreamTestHelper.STREAMS.length]);
        RpcResult result = sendRpcWithStream(streams);
        Assert.assertEquals(Sets.newHashSet(StreamTestHelper.STREAMS), result.successMessages);
        Assert.assertTrue(result.errorMessages.isEmpty());
    }

    /**
     * Exercises the three failing stream handlers. For the on-data failure the suite expects
     * the read-failure message plus a closed-channel error; for the on-complete failure it
     * expects a post-processing error while the following upload still succeeds.
     */
    @Test
    public void sendRpcWithStreamFailures() throws Exception {
        RpcResult onDataFailure = sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer");
        assertErrorAndClosed(onDataFailure, "Destination failed while reading stream");

        // NOTE(review): the result of the "fail/null" upload is discarded and the assertion below
        // re-checks the first result. This looks like a copy-paste slip; confirm which error
        // message the null-handler case should produce before asserting on its own result.
        sendRpcWithStream("fail/null/smallBuffer", "smallBuffer");
        assertErrorAndClosed(onDataFailure, "Destination failed while reading stream");

        RpcResult onCompleteFailure = sendRpcWithStream("fail/exception-oncomplete/smallBuffer", "smallBuffer");
        assertErrorsContain(onCompleteFailure.errorMessages, Sets.newHashSet("Failure post-processing"));
        Assert.assertEquals(Sets.newHashSet("smallBuffer"), onCompleteFailure.successMessages);
    }

    /**
     * Asserts a one-to-one match between actual and expected errors: the counts are equal,
     * every expected fragment occurs in some actual error, and no actual error is left over.
     *
     * @param set the actual error messages
     * @param set2 the fragments, each of which must be contained in a distinct actual error
     */
    private void assertErrorsContain(Set<String> set, Set<String> set2) {
        Assert.assertEquals("Expected " + set2.size() + " errors, got " + set.size() + " errors: " + set, set2.size(), set.size());
        // Left: actual errors nothing matched. Right: expected fragments that were not found.
        Pair<Set<String>, Set<String>> unmatched = checkErrorsContain(set, set2);
        Assert.assertTrue("Could not find error containing " + unmatched.getRight() + "; errors: " + set, unmatched.getRight().isEmpty());
        Assert.assertTrue(unmatched.getLeft().isEmpty());
    }

    /**
     * Asserts that a result has no successes and exactly two errors: one containing
     * {@code str} and one containing exactly one of the known "connection closed" messages.
     *
     * @param rpcResult the result to inspect
     * @param str the fragment the non-"closed" error must contain
     */
    private void assertErrorAndClosed(RpcResult rpcResult, String str) {
        Assert.assertTrue("unexpected success: " + rpcResult.successMessages, rpcResult.successMessages.isEmpty());
        Set<String> errors = rpcResult.errorMessages;
        Assert.assertEquals("Expected 2 errors, got " + errors.size() + " errors: " + errors, 2L, errors.size());

        // The closed-connection error can surface with any one of these messages.
        Set<String> possibleClosedErrors = Sets.newHashSet("closed", "Connection reset", "java.nio.channels.ClosedChannelException", "java.io.IOException: Broken pipe");
        Set<String> expected = Sets.newHashSet(str);
        expected.addAll(possibleClosedErrors);

        Pair<Set<String>, Set<String>> unmatched = checkErrorsContain(errors, expected);
        // Every actual error must have been matched by some expected fragment.
        Assert.assertTrue("Got a non-empty set " + unmatched.getLeft(), unmatched.getLeft().isEmpty());

        // Exactly one "closed" variant should have matched, so all the others remain unfound.
        Set<String> notFound = unmatched.getRight();
        Assert.assertEquals("The size of " + notFound + " was not " + (possibleClosedErrors.size() - 1), possibleClosedErrors.size() - 1, notFound.size());
        for (String missing : notFound) {
            Assert.assertTrue("Found a wrong error " + missing, expected.contains(missing));
        }
    }

    /**
     * Pairs each expected fragment with at most one actual error that contains it.
     *
     * <p>Works on a copy of {@code set}; a matched error is removed from the copy so it cannot
     * satisfy a second fragment.
     *
     * @param set the actual error messages (not modified)
     * @param set2 the fragments to look for
     * @return a pair whose left side holds the actual errors that matched no fragment and
     *     whose right side holds the fragments for which no error was found
     */
    private Pair<Set<String>, Set<String>> checkErrorsContain(Set<String> set, Set<String> set2) {
        Set<String> remaining = Sets.newHashSet(set);
        Set<String> notFound = Sets.newHashSet();
        for (String fragment : set2) {
            boolean matched = false;
            // Consume only the first error that contains the fragment.
            for (Iterator<String> it = remaining.iterator(); !matched && it.hasNext(); ) {
                if (it.next().contains(fragment)) {
                    it.remove();
                    matched = true;
                }
            }
            if (!matched) {
                notFound.add(fragment);
            }
        }
        return new ImmutablePair<>(remaining, notFound);
    }
}
