package org.apache.spark.network.client;

import io.netty.channel.Channel;
import java.io.IOException;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.ResponseMessage;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.server.MessageHandler;
import org.apache.spark.network.util.NettyUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/spark/network/client/TransportResponseHandler.class */
/**
 * Matches response messages arriving on a client channel to the requests still awaiting them.
 *
 * <p>Pending chunk fetches are keyed by {@link StreamChunkId} and pending RPCs by their request id.
 * When a {@link ResponseMessage} arrives, the matching callback is atomically removed from its map
 * and invoked; responses with no matching entry are logged and dropped. When the channel is
 * unregistered or reports an exception, every still-pending callback is failed.
 *
 * <p>Pending-request state lives in {@link ConcurrentHashMap}s and an {@link AtomicLong},
 * presumably because requests are registered/removed from threads other than the one delivering
 * responses (NOTE(review): confirm against callers). Because every completion path claims a
 * callback with a single {@code remove}, each callback is invoked at most once by this handler.
 */
public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
    /** Channel whose responses are handled; used here only to report the remote address. */
    private final Channel channel;
    private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);

    /** Chunk fetches awaiting a ChunkFetchSuccess or ChunkFetchFailure. */
    private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches =
        new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();

    /** RPCs awaiting an RpcResponse or RpcFailure, keyed by request id. */
    private final Map<Long, RpcResponseCallback> outstandingRpcs =
        new ConcurrentHashMap<Long, RpcResponseCallback>();

    /** {@link System#nanoTime()} of the most recently registered request; 0 until the first one. */
    private final AtomicLong timeOfLastRequestNs = new AtomicLong(0);

    /**
     * @param channel the channel this handler serves; only its remote address is used, for logging
     *     and error messages
     */
    public TransportResponseHandler(Channel channel) {
        this.channel = channel;
    }

    /**
     * Registers the callback to run when the response for {@code streamChunkId} arrives, and records
     * the current time as the time of the last request.
     */
    public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
        timeOfLastRequestNs.set(System.nanoTime());
        outstandingFetches.put(streamChunkId, callback);
    }

    /** Forgets a pending chunk fetch without invoking its callback. */
    public void removeFetchRequest(StreamChunkId streamChunkId) {
        outstandingFetches.remove(streamChunkId);
    }

    /**
     * Registers the callback to run when the response for RPC {@code requestId} arrives, and records
     * the current time as the time of the last request.
     */
    public void addRpcRequest(long requestId, RpcResponseCallback callback) {
        timeOfLastRequestNs.set(System.nanoTime());
        outstandingRpcs.put(Long.valueOf(requestId), callback);
    }

    /** Forgets a pending RPC without invoking its callback. */
    public void removeRpcRequest(long requestId) {
        outstandingRpcs.remove(Long.valueOf(requestId));
    }

    /**
     * Fails every pending fetch and RPC with {@code cause}.
     *
     * <p>Each entry is claimed with an atomic {@code remove} before its callback runs, so a callback
     * already completed by {@link #handle} is not invoked again. An exception thrown by one callback
     * is logged and does not stop the remaining callbacks from being notified. ConcurrentHashMap
     * iteration is weakly consistent, so a request registered while this sweep runs may not be
     * visited; it is left in place rather than being discarded without notification.
     */
    private void failOutstandingRequests(Throwable cause) {
        for (StreamChunkId streamChunkId : outstandingFetches.keySet()) {
            ChunkReceivedCallback callback = outstandingFetches.remove(streamChunkId);
            if (callback == null) {
                continue; // Completed or removed concurrently; nothing left to fail.
            }
            try {
                callback.onFailure(streamChunkId.chunkIndex, cause);
            } catch (Exception e) {
                logger.warn("ChunkReceivedCallback.onFailure threw an exception", e);
            }
        }
        for (Long requestId : outstandingRpcs.keySet()) {
            RpcResponseCallback callback = outstandingRpcs.remove(requestId);
            if (callback == null) {
                continue; // Completed or removed concurrently; nothing left to fail.
            }
            try {
                callback.onFailure(cause);
            } catch (Exception e) {
                logger.warn("RpcResponseCallback.onFailure threw an exception", e);
            }
        }
    }

    /** Fails all pending requests with an IOException naming the remote address, if any remain. */
    @Override // org.apache.spark.network.server.MessageHandler
    public void channelUnregistered() {
        int outstanding = numOutstandingRequests();
        if (outstanding > 0) {
            String remoteAddress = NettyUtils.getRemoteAddress(channel);
            logger.error("Still have {} requests outstanding when connection from {} is closed",
                Integer.valueOf(outstanding), remoteAddress);
            failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
        }
    }

    /** Fails all pending requests with {@code cause}, if any remain. */
    @Override // org.apache.spark.network.server.MessageHandler
    public void exceptionCaught(Throwable cause) {
        int outstanding = numOutstandingRequests();
        if (outstanding > 0) {
            logger.error("Still have {} requests outstanding when connection from {} is closed",
                Integer.valueOf(outstanding), NettyUtils.getRemoteAddress(channel));
            failOutstandingRequests(cause);
        }
    }

    /**
     * Dispatches a response to the callback registered for it.
     *
     * @throws IllegalStateException if the message is not one of the four known response types
     */
    @Override // org.apache.spark.network.server.MessageHandler
    public void handle(ResponseMessage message) {
        if (message instanceof ChunkFetchSuccess) {
            handleChunkFetchSuccess((ChunkFetchSuccess) message);
        } else if (message instanceof ChunkFetchFailure) {
            handleChunkFetchFailure((ChunkFetchFailure) message);
        } else if (message instanceof RpcResponse) {
            handleRpcResponse((RpcResponse) message);
        } else if (message instanceof RpcFailure) {
            handleRpcFailure((RpcFailure) message);
        } else {
            throw new IllegalStateException("Unknown response type: " + message.type());
        }
    }

    /** Delivers a fetched chunk to its callback, always releasing the message's buffer afterwards. */
    private void handleChunkFetchSuccess(ChunkFetchSuccess resp) {
        ChunkReceivedCallback callback = outstandingFetches.remove(resp.streamChunkId);
        try {
            if (callback == null) {
                logger.warn("Ignoring response for block {} from {} since it is not outstanding",
                    resp.streamChunkId, NettyUtils.getRemoteAddress(channel));
            } else {
                callback.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
            }
        } finally {
            // Released as soon as onSuccess returns (as before), and now also if it throws, so the
            // buffer is never leaked. A callback that needs the data longer presumably retains it.
            resp.buffer.release();
        }
    }

    /** Fails the matching fetch callback with a ChunkFetchFailureException carrying the error. */
    private void handleChunkFetchFailure(ChunkFetchFailure resp) {
        ChunkReceivedCallback callback = outstandingFetches.remove(resp.streamChunkId);
        if (callback == null) {
            logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
                new Object[] {resp.streamChunkId, NettyUtils.getRemoteAddress(channel), resp.errorString});
        } else {
            callback.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
                "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
        }
    }

    /** Delivers an RPC response payload to its callback. */
    private void handleRpcResponse(RpcResponse resp) {
        RpcResponseCallback callback = outstandingRpcs.remove(Long.valueOf(resp.requestId));
        if (callback == null) {
            logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
                new Object[] {Long.valueOf(resp.requestId), NettyUtils.getRemoteAddress(channel),
                    Integer.valueOf(resp.response.length)});
        } else {
            callback.onSuccess(resp.response);
        }
    }

    /** Fails the matching RPC callback with a RuntimeException carrying the remote error string. */
    private void handleRpcFailure(RpcFailure resp) {
        RpcResponseCallback callback = outstandingRpcs.remove(Long.valueOf(resp.requestId));
        if (callback == null) {
            logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
                new Object[] {Long.valueOf(resp.requestId), NettyUtils.getRemoteAddress(channel),
                    resp.errorString});
        } else {
            callback.onFailure(new RuntimeException(resp.errorString));
        }
    }

    /** Returns the number of fetches plus RPCs still awaiting a response. */
    public int numOutstandingRequests() {
        return outstandingFetches.size() + outstandingRpcs.size();
    }

    /** Returns the {@link System#nanoTime()} value of the last registered request, or 0 if none. */
    public long getTimeOfLastRequestNs() {
        return timeOfLastRequestNs.get();
    }
}
