package org.apache.tez.dag.app;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import java.io.IOException;
import java.util.Set;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.dag.app.dag.DAG;
import org.apache.tez.dag.app.dag.Vertex;
import org.apache.tez.dag.app.dag.VertexStateUpdateListener;
import org.apache.tez.dag.app.rm.container.AMContainer;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.TaskFailureType;
import org.apache.tez.serviceplugins.api.DagInfo;
import org.apache.tez.serviceplugins.api.ServicePluginError;
import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
import org.apache.tez.serviceplugins.api.TaskCommunicatorContext;
import org.apache.tez.serviceplugins.api.TaskHeartbeatRequest;
import org.apache.tez.serviceplugins.api.TaskHeartbeatResponse;

/**
 * Context object handed to a single task communicator plugin.
 *
 * <p>Most operations are thin delegations to the owning {@link TaskCommunicatorManager} or to the
 * {@link AppContext}. DAG-scoped queries are answered from the DAG returned by {@link #getDag()},
 * which is either a DAG explicitly pinned via {@link #dagCompleteStart(DAG)} or, when none is
 * pinned, whatever {@code AppContext.getCurrentDAG()} returns at the time of the call.
 *
 * <p>Reads and writes of the pinned DAG reference are guarded by a read/write lock pair.
 */
@InterfaceAudience.Private
public class TaskCommunicatorContextImpl implements TaskCommunicatorContext, VertexStateUpdateListener {
    private final AppContext context;
    private final TaskCommunicatorManager taskCommunicatorManager;
    /** Index identifying this communicator; compared against a container's communicator identifier. */
    private final int taskCommunicatorIndex;
    private final ReentrantReadWriteLock.ReadLock dagChangedReadLock;
    private final ReentrantReadWriteLock.WriteLock dagChangedWriteLock;
    private final UserPayload userPayload;
    /** DAG pinned between dagCompleteStart and dagCompleteEnd; null when no DAG is pinned. */
    private DAG dag;

    /**
     * Creates a context bound to one task communicator.
     *
     * @param appContext application context used for application-level lookups
     * @param taskCommunicatorManager manager that all task/container notifications are forwarded to
     * @param userPayload payload returned unchanged by {@link #getInitialUserPayload()}
     * @param i index of the task communicator this context belongs to
     */
    public TaskCommunicatorContextImpl(AppContext appContext, TaskCommunicatorManager taskCommunicatorManager, UserPayload userPayload, int i) {
        this.context = appContext;
        this.taskCommunicatorManager = taskCommunicatorManager;
        this.userPayload = userPayload;
        this.taskCommunicatorIndex = i;
        // Both lock views come from the same underlying lock so readers and writers exclude each other.
        ReentrantReadWriteLock dagChangedLock = new ReentrantReadWriteLock();
        this.dagChangedReadLock = dagChangedLock.readLock();
        this.dagChangedWriteLock = dagChangedLock.writeLock();
    }

    /** Returns the payload supplied at construction time. */
    public UserPayload getInitialUserPayload() {
        return this.userPayload;
    }

    /** Returns the application attempt id reported by the application context. */
    @Override
    public ApplicationAttemptId getApplicationAttemptId() {
        return this.context.getApplicationAttemptId();
    }

    /** Returns the credentials reported by the application context. */
    @Override
    public Credentials getAMCredentials() {
        return this.context.getAppCredentials();
    }

    /**
     * Asks the manager whether the given attempt may commit.
     *
     * @throws IOException propagated from the manager
     */
    @Override
    public boolean canCommit(TezTaskAttemptID tezTaskAttemptID) throws IOException {
        return this.taskCommunicatorManager.canCommit(tezTaskAttemptID);
    }

    /**
     * Forwards a heartbeat request to the manager and returns its response.
     *
     * @throws IOException propagated from the manager
     * @throws TezException propagated from the manager
     */
    @Override
    public TaskHeartbeatResponse heartbeat(TaskHeartbeatRequest taskHeartbeatRequest) throws IOException, TezException {
        return this.taskCommunicatorManager.heartbeat(taskHeartbeatRequest);
    }

    /**
     * Reports whether the container is registered with the application context and is associated
     * with this communicator's index.
     */
    @Override
    public boolean isKnownContainer(ContainerId containerId) {
        AMContainer aMContainer = this.context.getAllContainers().get(containerId);
        return aMContainer != null && aMContainer.getTaskCommunicatorIdentifier() == this.taskCommunicatorIndex;
    }

    /** Forwards a task liveness notification to the manager. */
    @Override
    public void taskAlive(TezTaskAttemptID tezTaskAttemptID) {
        this.taskCommunicatorManager.taskAlive(tezTaskAttemptID);
    }

    /**
     * Forwards a container liveness notification to the manager, but only for containers accepted
     * by {@link #isKnownContainer(ContainerId)}; notifications for other containers are dropped.
     */
    @Override
    public void containerAlive(ContainerId containerId) {
        if (isKnownContainer(containerId)) {
            this.taskCommunicatorManager.containerAlive(containerId);
        }
    }

    /** Forwards a task-started notification to the manager. */
    @Override
    public void taskStartedRemotely(TezTaskAttemptID tezTaskAttemptID, ContainerId containerId) {
        this.taskCommunicatorManager.taskStartedRemotely(tezTaskAttemptID, containerId);
    }

    /** Forwards a task-killed notification, with its end reason and optional diagnostics, to the manager. */
    @Override
    public void taskKilled(TezTaskAttemptID tezTaskAttemptID, TaskAttemptEndReason taskAttemptEndReason, @Nullable String str) {
        this.taskCommunicatorManager.taskKilled(tezTaskAttemptID, taskAttemptEndReason, str);
    }

    /** Forwards a task-failed notification, with its failure type, end reason and optional diagnostics, to the manager. */
    @Override
    public void taskFailed(TezTaskAttemptID tezTaskAttemptID, TaskFailureType taskFailureType, TaskAttemptEndReason taskAttemptEndReason, @Nullable String str) {
        this.taskCommunicatorManager.taskFailed(tezTaskAttemptID, taskFailureType, taskAttemptEndReason, str);
    }

    /**
     * Registers this context as the listener for state updates of the named vertex on the DAG
     * returned by {@link #getDag()}.
     *
     * @param str vertex name; must not be null
     * @param set states passed through to the notifier; may be null
     * @throws NullPointerException if {@code str} is null
     */
    @Override
    public void registerForVertexStateUpdates(String str, @Nullable Set<VertexState> set) {
        Preconditions.checkNotNull(str, "VertexName cannot be null: " + str);
        getDag().getStateChangeNotifier().registerForVertexUpdates(str, set, this);
    }

    /** Returns the string form of the application id reported by the application context. */
    @Override
    public String getCurrentAppIdentifier() {
        return this.context.getApplicationID().toString();
    }

    /** Returns the DAG from {@link #getDag()} as a {@link DagInfo}; may be null. */
    @Nullable
    public DagInfo getCurrentDagInfo() {
        return getDag();
    }

    /**
     * Returns a lazily transformed view of the names of the input vertices of the named vertex.
     *
     * @param str vertex name; must not be null
     * @throws NullPointerException if {@code str} is null
     */
    @Override
    public Iterable<String> getInputVertexNames(String str) {
        Preconditions.checkNotNull(str, "VertexName cannot be null: " + str);
        Vertex vertex = getDag().getVertex(str);
        Set<Vertex> inputVertices = vertex.getInputVertices().keySet();
        return Iterables.transform(inputVertices, new Function<Vertex, String>() {
            @Override
            public String apply(Vertex vertex) {
                return vertex.getName();
            }
        });
    }

    /**
     * Returns the total task count of the named vertex.
     *
     * @throws IllegalArgumentException if {@code str} is null
     */
    @Override
    public int getVertexTotalTaskCount(String str) {
        Preconditions.checkArgument(str != null, "VertexName must be specified");
        return getDag().getVertex(str).getTotalTasks();
    }

    /**
     * Returns the completed task count of the named vertex.
     *
     * @throws IllegalArgumentException if {@code str} is null
     */
    @Override
    public int getVertexCompletedTaskCount(String str) {
        Preconditions.checkArgument(str != null, "VertexName must be specified");
        return getDag().getVertex(str).getCompletedTasks();
    }

    /**
     * Returns the running task count of the named vertex.
     *
     * @throws IllegalArgumentException if {@code str} is null
     */
    @Override
    public int getVertexRunningTaskCount(String str) {
        Preconditions.checkArgument(str != null, "VertexName must be specified");
        return getDag().getVertex(str).getRunningTasks();
    }

    /**
     * Returns the first-attempt start time of a task in the named vertex.
     *
     * @param str vertex name; must not be null
     * @param i task index; must be zero or greater
     * @throws IllegalArgumentException if {@code str} is null or {@code i} is negative
     */
    @Override
    public long getFirstAttemptStartTime(String str, int i) {
        Preconditions.checkArgument(str != null, "VertexName must be specified");
        // Index 0 is accepted, so the message must say ">= 0" to match the check.
        Preconditions.checkArgument(i >= 0, "TaskIndex must be >= 0");
        return getDag().getVertex(str).getTask(i).getFirstAttemptStartTime();
    }

    /** Returns the start time of the DAG returned by {@link #getDag()}. */
    @Override
    public long getDagStartTime() {
        return getDag().getStartTime();
    }

    /**
     * Reports a plugin error to the manager, tagged with this communicator's index.
     *
     * @param servicePluginError the error; must not be null
     * @param str diagnostic message passed through to the manager
     * @param dagInfo DAG information passed through to the manager
     * @throws NullPointerException if {@code servicePluginError} is null
     */
    public void reportError(@Nonnull ServicePluginError servicePluginError, String str, DagInfo dagInfo) {
        Preconditions.checkNotNull(servicePluginError, "ServicePluginError must be set");
        this.taskCommunicatorManager.reportError(this.taskCommunicatorIndex, servicePluginError, str, dagInfo);
    }

    /** Relays a vertex state update to the manager, tagged with this communicator's index. */
    @Override
    public void onStateUpdated(VertexStateUpdate vertexStateUpdate) {
        this.taskCommunicatorManager.vertexStateUpdateNotificationReceived(vertexStateUpdate, this.taskCommunicatorIndex);
    }

    /**
     * Resolves the DAG to use for DAG-scoped queries: the pinned DAG when one is set, otherwise
     * the application context's current DAG. The lookup runs under the read lock.
     */
    private DAG getDag() {
        this.dagChangedReadLock.lock();
        try {
            return this.dag != null ? this.dag : this.context.getCurrentDAG();
        } finally {
            this.dagChangedReadLock.unlock();
        }
    }

    /**
     * Pins the given DAG, under the write lock, so that later lookups resolve to it.
     * NOTE(review): presumably called when DAG completion begins so plugin callbacks keep seeing
     * the completing DAG; confirm against the caller.
     */
    @InterfaceAudience.Private
    public void dagCompleteStart(DAG dag) {
        this.dagChangedWriteLock.lock();
        try {
            this.dag = dag;
        } finally {
            this.dagChangedWriteLock.unlock();
        }
    }

    /** Clears the pinned DAG, under the write lock, so lookups fall back to the context's current DAG. */
    @InterfaceAudience.Private
    public void dagCompleteEnd() {
        this.dagChangedWriteLock.lock();
        try {
            this.dag = null;
        } finally {
            this.dagChangedWriteLock.unlock();
        }
    }
}
