package org.apache.spark.network;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

import io.netty.channel.Channel;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.protocol.MergedBlockMetaRequest;
import org.apache.spark.network.protocol.MergedBlockMetaSuccess;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.StreamFailure;
import org.apache.spark.network.protocol.StreamRequest;
import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.server.ChunkFetchRequestHandler;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportRequestHandler;

/* loaded from: input_file:org/apache/spark/network/TransportRequestHandlerSuite.class */
public class TransportRequestHandlerSuite {
    @Test
    public void handleStreamRequest() throws Exception {
        NoOpRpcHandler noOpRpcHandler = new NoOpRpcHandler();
        OneForOneStreamManager streamManager = noOpRpcHandler.getStreamManager();
        Channel channel = (Channel) Mockito.mock(Channel.class);
        ArrayList arrayList = new ArrayList();
        Mockito.when(channel.writeAndFlush(Mockito.any())).thenAnswer(invocationOnMock -> {
            Object obj = invocationOnMock.getArguments()[0];
            ExtendedChannelPromise extendedChannelPromise = new ExtendedChannelPromise(channel);
            arrayList.add(ImmutablePair.of(obj, extendedChannelPromise));
            return extendedChannelPromise;
        });
        ArrayList arrayList2 = new ArrayList();
        arrayList2.add(new TestManagedBuffer(10));
        arrayList2.add(new TestManagedBuffer(20));
        arrayList2.add(null);
        arrayList2.add(new TestManagedBuffer(30));
        arrayList2.add(new TestManagedBuffer(40));
        long registerStream = streamManager.registerStream("test-app", arrayList2.iterator(), channel);
        Assert.assertEquals(1L, streamManager.numStreamStates());
        TransportRequestHandler transportRequestHandler = new TransportRequestHandler(channel, (TransportClient) Mockito.mock(TransportClient.class), noOpRpcHandler, 2L, (ChunkFetchRequestHandler) null);
        transportRequestHandler.handle(new StreamRequest(String.format("%d_%d", Long.valueOf(registerStream), 0)));
        Assert.assertEquals(1L, arrayList.size());
        Assert.assertTrue(((Pair) arrayList.get(0)).getLeft() instanceof StreamResponse);
        Assert.assertEquals(arrayList2.get(0), ((StreamResponse) ((Pair) arrayList.get(0)).getLeft()).body());
        transportRequestHandler.handle(new StreamRequest(String.format("%d_%d", Long.valueOf(registerStream), 1)));
        Assert.assertEquals(2L, arrayList.size());
        Assert.assertTrue(((Pair) arrayList.get(1)).getLeft() instanceof StreamResponse);
        Assert.assertEquals(arrayList2.get(1), ((StreamResponse) ((Pair) arrayList.get(1)).getLeft()).body());
        ((ExtendedChannelPromise) ((Pair) arrayList.get(0)).getRight()).finish(true);
        StreamRequest streamRequest = new StreamRequest(String.format("%d_%d", Long.valueOf(registerStream), 2));
        transportRequestHandler.handle(streamRequest);
        Assert.assertEquals(3L, arrayList.size());
        Assert.assertTrue(((Pair) arrayList.get(2)).getLeft() instanceof StreamFailure);
        Assert.assertEquals(String.format("Stream '%s' was not found.", streamRequest.streamId), ((StreamFailure) ((Pair) arrayList.get(2)).getLeft()).error);
        transportRequestHandler.handle(new StreamRequest(String.format("%d_%d", Long.valueOf(registerStream), 3)));
        Assert.assertEquals(4L, arrayList.size());
        Assert.assertTrue(((Pair) arrayList.get(3)).getLeft() instanceof StreamResponse);
        Assert.assertEquals(arrayList2.get(3), ((StreamResponse) ((Pair) arrayList.get(3)).getLeft()).body());
        transportRequestHandler.handle(new StreamRequest(String.format("%d_%d", Long.valueOf(registerStream), 4)));
        ((Channel) Mockito.verify(channel, Mockito.times(1))).close();
        Assert.assertEquals(4L, arrayList.size());
        streamManager.connectionTerminated(channel);
        Assert.assertEquals(0L, streamManager.numStreamStates());
    }

    @Test
    public void handleMergedBlockMetaRequest() throws Exception {
        final RpcHandler.MergedBlockMetaReqHandler mergedBlockMetaReqHandler = (transportClient, mergedBlockMetaRequest, mergedBlockMetaResponseCallback) -> {
            if (mergedBlockMetaRequest.shuffleId == -1 || mergedBlockMetaRequest.reduceId == -1) {
                mergedBlockMetaResponseCallback.onFailure(new RuntimeException("empty block"));
            } else {
                mergedBlockMetaResponseCallback.onSuccess(2, (ManagedBuffer) Mockito.mock(ManagedBuffer.class));
            }
        };
        RpcHandler rpcHandler = new RpcHandler() { // from class: org.apache.spark.network.TransportRequestHandlerSuite.1
            public void receive(TransportClient transportClient2, ByteBuffer byteBuffer, RpcResponseCallback rpcResponseCallback) {
            }

            public StreamManager getStreamManager() {
                return null;
            }

            public RpcHandler.MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
                return mergedBlockMetaReqHandler;
            }
        };
        Channel channel = (Channel) Mockito.mock(Channel.class);
        ArrayList arrayList = new ArrayList();
        Mockito.when(channel.writeAndFlush(Mockito.any())).thenAnswer(invocationOnMock -> {
            Object obj = invocationOnMock.getArguments()[0];
            ExtendedChannelPromise extendedChannelPromise = new ExtendedChannelPromise(channel);
            arrayList.add(ImmutablePair.of(obj, extendedChannelPromise));
            return extendedChannelPromise;
        });
        TransportRequestHandler transportRequestHandler = new TransportRequestHandler(channel, (TransportClient) Mockito.mock(TransportClient.class), rpcHandler, 2L, (ChunkFetchRequestHandler) null);
        transportRequestHandler.handle(new MergedBlockMetaRequest(19L, "app1", 0, 0, 0));
        Assert.assertEquals(1L, arrayList.size());
        Assert.assertTrue(((Pair) arrayList.get(0)).getLeft() instanceof MergedBlockMetaSuccess);
        Assert.assertEquals(2L, ((MergedBlockMetaSuccess) ((Pair) arrayList.get(0)).getLeft()).getNumChunks());
        transportRequestHandler.handle(new MergedBlockMetaRequest(21L, "app1", -1, 0, 1));
        Assert.assertEquals(2L, arrayList.size());
        Assert.assertTrue(((Pair) arrayList.get(1)).getLeft() instanceof RpcFailure);
    }
}
