/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.dag.app;

import java.io.IOException;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.NodeId;
import org.apache.hadoop.yarn.event.Event;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.tez.common.TezUtils;
import org.apache.tez.dag.api.NamedEntityDescriptor;
import org.apache.tez.dag.api.TezConstants;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.dag.app.AppContext;
import org.apache.tez.dag.app.ContainerHeartbeatHandler;
import org.apache.tez.dag.app.TaskCommunicatorContextImpl;
import org.apache.tez.dag.app.TaskCommunicatorManager;
import org.apache.tez.dag.app.TaskHeartbeatHandler;
import org.apache.tez.dag.app.dag.DAG;
import org.apache.tez.dag.app.dag.event.DAGAppMasterEventType;
import org.apache.tez.dag.app.dag.event.DAGAppMasterEventUserServiceFatalError;
import org.apache.tez.dag.app.dag.event.DAGEventTerminateDag;
import org.apache.tez.dag.helpers.DagInfoImplForTest;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.impl.TaskSpec;
import org.apache.tez.serviceplugins.api.ContainerEndReason;
import org.apache.tez.serviceplugins.api.DagInfo;
import org.apache.tez.serviceplugins.api.ServicePluginError;
import org.apache.tez.serviceplugins.api.ServicePluginErrorDefaults;
import org.apache.tez.serviceplugins.api.ServicePluginException;
import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
import org.apache.tez.serviceplugins.api.TaskCommunicator;
import org.apache.tez.serviceplugins.api.TaskCommunicatorContext;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.mockito.verification.VerificationMode;

/**
 * Tests for {@link TaskCommunicatorManager}.
 *
 * <p>Covers the following:
 * <ul>
 *   <li>selection of task communicator plugins from {@link NamedEntityDescriptor}s;</li>
 *   <li>delivery of each descriptor's user payload to its communicator context;</li>
 *   <li>routing of container registration to the communicator at a given index;</li>
 *   <li>translation of plugin-reported errors and plugin exceptions into DAG / AM events.</li>
 * </ul>
 */
public class TestTaskCommunicatorManager {
    private static final String DAG_NAME = "dagName";
    private static final int DAG_INDEX = 1;
    private static final String CUSTOM_TASK_COMM_NAME = "customTaskComm";

    /**
     * Clears the static bookkeeping (and shared mocks) of
     * {@link TaskCommManagerForMultipleCommTest} before and after every test so that
     * tests cannot observe each other's state, regardless of execution order.
     */
    @Before
    @After
    public void resetForNextTest() {
        TaskCommManagerForMultipleCommTest.reset();
    }

    /**
     * Builds the 4-byte buffer (int value 3 at offset 0) used as the custom task
     * communicator's user payload. Tests keep the returned instance to compare it
     * against what the communicator context later exposes.
     */
    private static ByteBuffer createCustomPayloadBuffer() {
        ByteBuffer buffer = ByteBuffer.allocate(4);
        buffer.putInt(0, 3);
        return buffer;
    }

    /** Constructing the manager with a null descriptor list must be rejected. */
    @Test(timeout=5000L)
    public void testNoTaskCommSpecified() throws IOException, TezException {
        AppContext appContext = Mockito.mock(AppContext.class);
        TaskHeartbeatHandler thh = Mockito.mock(TaskHeartbeatHandler.class);
        ContainerHeartbeatHandler chh = Mockito.mock(ContainerHeartbeatHandler.class);
        try {
            new TaskCommManagerForMultipleCommTest(appContext, thh, chh, null);
            Assert.fail("Initialization should have failed without a TaskComm specified");
        } catch (IllegalArgumentException expected) {
            // Expected: the constructor rejects a null descriptor list.
        }
    }

    /**
     * A single custom descriptor results in exactly one communicator.
     *
     * <p>Neither the default (YARN) nor the uber communicator is created. The custom
     * communicator's context carries the descriptor's user payload.
     */
    @Test(timeout=5000L)
    public void testCustomTaskCommSpecified() throws IOException, TezException {
        AppContext appContext = Mockito.mock(AppContext.class);
        TaskHeartbeatHandler thh = Mockito.mock(TaskHeartbeatHandler.class);
        ContainerHeartbeatHandler chh = Mockito.mock(ContainerHeartbeatHandler.class);

        ByteBuffer bb = createCustomPayloadBuffer();
        UserPayload customPayload = UserPayload.create(bb);
        List<NamedEntityDescriptor> taskCommDescriptors = new LinkedList<NamedEntityDescriptor>();
        taskCommDescriptors.add(new NamedEntityDescriptor(CUSTOM_TASK_COMM_NAME, FakeTaskComm.class.getName())
                .setUserPayload(customPayload));

        TaskCommManagerForMultipleCommTest tcm =
                new TaskCommManagerForMultipleCommTest(appContext, thh, chh, taskCommDescriptors);
        try {
            tcm.init(new Configuration(false));
            tcm.start();
            Assert.assertEquals(1, TaskCommManagerForMultipleCommTest.getNumTaskComms());
            Assert.assertFalse(TaskCommManagerForMultipleCommTest.getYarnTaskCommCreated());
            Assert.assertFalse(TaskCommManagerForMultipleCommTest.getUberTaskCommCreated());
            Assert.assertEquals(CUSTOM_TASK_COMM_NAME, TaskCommManagerForMultipleCommTest.getTaskCommName(0));
            Assert.assertEquals(bb,
                    TaskCommManagerForMultipleCommTest.getTaskCommContext(0).getInitialUserPayload().getPayload());
        } finally {
            tcm.stop();
        }
    }

    /**
     * A custom descriptor followed by the built-in YARN descriptor creates both communicators.
     *
     * <p>The communicators are created in descriptor order, with the YARN one built through
     * {@code createDefaultTaskCommunicator}. Each communicator's context receives its own
     * descriptor's payload.
     */
    @Test(timeout=5000L)
    public void testMultipleTaskComms() throws IOException, TezException {
        AppContext appContext = Mockito.mock(AppContext.class);
        TaskHeartbeatHandler thh = Mockito.mock(TaskHeartbeatHandler.class);
        ContainerHeartbeatHandler chh = Mockito.mock(ContainerHeartbeatHandler.class);

        Configuration conf = new Configuration(false);
        conf.set("testkey", "testvalue");
        UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);

        ByteBuffer bb = createCustomPayloadBuffer();
        UserPayload customPayload = UserPayload.create(bb);
        List<NamedEntityDescriptor> taskCommDescriptors = new LinkedList<NamedEntityDescriptor>();
        taskCommDescriptors.add(new NamedEntityDescriptor(CUSTOM_TASK_COMM_NAME, FakeTaskComm.class.getName())
                .setUserPayload(customPayload));
        taskCommDescriptors.add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
                .setUserPayload(defaultPayload));

        TaskCommManagerForMultipleCommTest tcm =
                new TaskCommManagerForMultipleCommTest(appContext, thh, chh, taskCommDescriptors);
        try {
            tcm.init(new Configuration(false));
            tcm.start();
            Assert.assertEquals(2, TaskCommManagerForMultipleCommTest.getNumTaskComms());
            Assert.assertTrue(TaskCommManagerForMultipleCommTest.getYarnTaskCommCreated());
            Assert.assertFalse(TaskCommManagerForMultipleCommTest.getUberTaskCommCreated());

            Assert.assertEquals(CUSTOM_TASK_COMM_NAME, TaskCommManagerForMultipleCommTest.getTaskCommName(0));
            Assert.assertEquals(bb,
                    TaskCommManagerForMultipleCommTest.getTaskCommContext(0).getInitialUserPayload().getPayload());

            Assert.assertEquals(TezConstants.getTezYarnServicePluginName(),
                    TaskCommManagerForMultipleCommTest.getTaskCommName(1));
            Configuration confParsed = TezUtils.createConfFromUserPayload(
                    TaskCommManagerForMultipleCommTest.getTaskCommContext(1).getInitialUserPayload());
            Assert.assertEquals("testvalue", confParsed.get("testkey"));
        } finally {
            tcm.stop();
        }
    }

    /**
     * Verifies the communicator lifecycle and container routing.
     *
     * <p>Every communicator is initialized and started with the manager, and shut down when
     * the manager stops. {@code registerRunningContainer(containerId, index)} is forwarded to
     * the communicator at {@code index}, with host and port taken from the container's
     * {@link NodeId}.
     */
    @Test(timeout=5000L)
    public void testEventRouting() throws Exception {
        AppContext appContext = Mockito.mock(AppContext.class, Mockito.RETURNS_DEEP_STUBS);
        NodeId nodeId = NodeId.newInstance("host1", 3131);
        Mockito.when(appContext.getAllContainers().get((ContainerId) Mockito.any()).getContainer().getNodeId())
                .thenReturn(nodeId);
        TaskHeartbeatHandler thh = Mockito.mock(TaskHeartbeatHandler.class);
        ContainerHeartbeatHandler chh = Mockito.mock(ContainerHeartbeatHandler.class);

        Configuration conf = new Configuration(false);
        conf.set("testkey", "testvalue");
        UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);

        UserPayload customPayload = UserPayload.create(createCustomPayloadBuffer());
        List<NamedEntityDescriptor> taskCommDescriptors = new LinkedList<NamedEntityDescriptor>();
        taskCommDescriptors.add(new NamedEntityDescriptor(CUSTOM_TASK_COMM_NAME, FakeTaskComm.class.getName())
                .setUserPayload(customPayload));
        taskCommDescriptors.add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
                .setUserPayload(defaultPayload));

        TaskCommManagerForMultipleCommTest tcm =
                new TaskCommManagerForMultipleCommTest(appContext, thh, chh, taskCommDescriptors);
        try {
            tcm.init(new Configuration(false));
            tcm.start();
            Assert.assertEquals(2, TaskCommManagerForMultipleCommTest.getNumTaskComms());
            Assert.assertTrue(TaskCommManagerForMultipleCommTest.getYarnTaskCommCreated());
            Assert.assertFalse(TaskCommManagerForMultipleCommTest.getUberTaskCommCreated());

            TaskCommunicator customComm = TaskCommManagerForMultipleCommTest.getTestTaskComm(0);
            TaskCommunicator yarnComm = TaskCommManagerForMultipleCommTest.getTestTaskComm(1);
            Mockito.verify(customComm).initialize();
            Mockito.verify(customComm).start();
            Mockito.verify(yarnComm).initialize();
            Mockito.verify(yarnComm).start();

            ContainerId containerId1 = Mockito.mock(ContainerId.class);
            tcm.registerRunningContainer(containerId1, 0);
            Mockito.verify(customComm).registerRunningContainer(
                    Mockito.eq(containerId1), Mockito.eq("host1"), Mockito.eq(3131));

            ContainerId containerId2 = Mockito.mock(ContainerId.class);
            tcm.registerRunningContainer(containerId2, 1);
            Mockito.verify(yarnComm).registerRunningContainer(
                    Mockito.eq(containerId2), Mockito.eq("host1"), Mockito.eq(3131));
        } finally {
            tcm.stop();
        }
        // Verified after (not inside) the finally block so that a failure in the body above
        // is reported as-is instead of being masked by a verification error during unwinding.
        Mockito.verify(tcm.getTaskCommunicator(0).getTaskCommunicator()).shutdown();
        Mockito.verify(tcm.getTaskCommunicator(1).getTaskCommunicator()).shutdown();
    }

    /**
     * Errors reported by a plugin through {@code TaskCommunicatorContext.reportError} are
     * surfaced as events on the app's event handler.
     *
     * <ul>
     *   <li>An error reported with a {@link DagInfo} produces a {@link DAGEventTerminateDag}.</li>
     *   <li>An error reported without one produces a {@link DAGAppMasterEventUserServiceFatalError}.</li>
     * </ul>
     */
    @Test(timeout=5000L)
    public void testReportFailureFromTaskCommunicator() throws TezException {
        EventHandler eventHandler = Mockito.mock(EventHandler.class);
        AppContext appContext = Mockito.mock(AppContext.class, Mockito.RETURNS_DEEP_STUBS);
        Mockito.doReturn("testTaskCommunicator").when(appContext).getTaskCommunicatorName(0);
        Mockito.doReturn(eventHandler).when(appContext).getEventHandler();

        DAG dag = Mockito.mock(DAG.class);
        TezDAGID dagId = TezDAGID.getInstance(ApplicationId.newInstance(1L, 0), 1);
        Mockito.doReturn(DAG_NAME).when(dag).getName();
        Mockito.doReturn(dagId).when(dag).getID();
        Mockito.doReturn(dag).when(appContext).getCurrentDAG();

        NamedEntityDescriptor namedEntityDescriptor =
                new NamedEntityDescriptor("testTaskCommunicator", TaskCommForFailureTest.class.getName());
        List<NamedEntityDescriptor> list = new LinkedList<NamedEntityDescriptor>();
        list.add(namedEntityDescriptor);

        TaskCommunicatorManager taskCommManager = new TaskCommunicatorManager(appContext,
                Mockito.mock(TaskHeartbeatHandler.class), Mockito.mock(ContainerHeartbeatHandler.class), list);
        try {
            taskCommManager.init(new Configuration());
            taskCommManager.start();

            // TaskCommForFailureTest reports SERVICE_UNAVAILABLE with DagInfo -> DAG kill.
            taskCommManager.registerRunningContainer(Mockito.mock(ContainerId.class), 0);
            ArgumentCaptor<Event> argumentCaptor = ArgumentCaptor.forClass(Event.class);
            Mockito.verify(eventHandler, Mockito.times(1)).handle(argumentCaptor.capture());
            Event rawEvent = argumentCaptor.getValue();
            Assert.assertTrue(rawEvent instanceof DAGEventTerminateDag);
            DAGEventTerminateDag killEvent = (DAGEventTerminateDag) rawEvent;
            Assert.assertTrue(killEvent.getDiagnosticInfo().contains("ReportError"));
            Assert.assertTrue(killEvent.getDiagnosticInfo().contains(ServicePluginErrorDefaults.SERVICE_UNAVAILABLE.name()));
            Assert.assertTrue(killEvent.getDiagnosticInfo().contains("[0:testTaskCommunicator]"));

            // TaskCommForFailureTest reports INCONSISTENT_STATE without DagInfo -> AM fatal error.
            Mockito.reset(eventHandler);
            taskCommManager.dagComplete(dag);
            argumentCaptor = ArgumentCaptor.forClass(Event.class);
            Mockito.verify(eventHandler, Mockito.times(1)).handle(argumentCaptor.capture());
            rawEvent = argumentCaptor.getValue();
            Assert.assertTrue(rawEvent instanceof DAGAppMasterEventUserServiceFatalError);
            DAGAppMasterEventUserServiceFatalError event = (DAGAppMasterEventUserServiceFatalError) rawEvent;
            Assert.assertEquals(DAGAppMasterEventType.TASK_COMMUNICATOR_SERVICE_FATAL_ERROR, event.getType());
            Assert.assertTrue(event.getDiagnosticInfo().contains("ReportedFatalError"));
            Assert.assertTrue(event.getDiagnosticInfo().contains(ServicePluginErrorDefaults.INCONSISTENT_STATE.name()));
            Assert.assertTrue(event.getDiagnosticInfo().contains("[0:testTaskCommunicator]"));
        } finally {
            taskCommManager.stop();
        }
    }

    /**
     * Exceptions thrown from communicator plugin methods are converted into
     * {@code TASK_COMMUNICATOR_SERVICE_FATAL_ERROR} events, not propagated to the caller.
     *
     * <p>Each event's diagnostics identify the failing operation and the plugin.
     */
    @Test(timeout=5000L)
    public void testTaskCommunicatorUserError() {
        TaskCommunicatorContextImpl taskCommContext = Mockito.mock(TaskCommunicatorContextImpl.class);
        TaskCommunicator taskCommunicator = Mockito.mock(TaskCommunicator.class, new ExceptionAnswer());
        Mockito.doReturn(taskCommContext).when(taskCommunicator).getContext();

        EventHandler eventHandler = Mockito.mock(EventHandler.class);
        AppContext appContext = Mockito.mock(AppContext.class, Mockito.RETURNS_DEEP_STUBS);
        Mockito.when(appContext.getEventHandler()).thenReturn(eventHandler);
        Mockito.doReturn("testTaskCommunicator").when(appContext).getTaskCommunicatorName(0);

        String expectedId = "[0:testTaskCommunicator]";
        Configuration conf = new Configuration(false);
        TaskCommunicatorManager taskCommunicatorManager = new TaskCommunicatorManager(taskCommunicator, appContext,
                Mockito.mock(TaskHeartbeatHandler.class), Mockito.mock(ContainerHeartbeatHandler.class));
        try {
            taskCommunicatorManager.init(conf);
            taskCommunicatorManager.start();

            // dagComplete throws inside the plugin -> first fatal-error event.
            DAG mockDag = Mockito.mock(DAG.class, Mockito.RETURNS_DEEP_STUBS);
            Mockito.when(mockDag.getID().getId()).thenReturn(1);
            taskCommunicatorManager.dagComplete(mockDag);
            ArgumentCaptor<Event> argumentCaptor = ArgumentCaptor.forClass(Event.class);
            Mockito.verify(eventHandler, Mockito.times(1)).handle(argumentCaptor.capture());
            Event rawEvent = argumentCaptor.getValue();
            Assert.assertTrue(rawEvent instanceof DAGAppMasterEventUserServiceFatalError);
            DAGAppMasterEventUserServiceFatalError event = (DAGAppMasterEventUserServiceFatalError) rawEvent;
            Assert.assertEquals(DAGAppMasterEventType.TASK_COMMUNICATOR_SERVICE_FATAL_ERROR, event.getType());
            Assert.assertTrue(event.getError().getMessage().contains("TestException_dagComplete"));
            Assert.assertTrue(event.getDiagnosticInfo().contains("DAG completion"));
            Assert.assertTrue(event.getDiagnosticInfo().contains(expectedId));

            // registerRunningContainer throws inside the plugin -> second fatal-error event.
            NodeId mockNodeId = Mockito.mock(NodeId.class);
            Mockito.when(appContext.getAllContainers().get((ContainerId) Mockito.any()).getContainer().getNodeId())
                    .thenReturn(mockNodeId);
            taskCommunicatorManager.registerRunningContainer(Mockito.mock(ContainerId.class), 0);
            argumentCaptor = ArgumentCaptor.forClass(Event.class);
            Mockito.verify(eventHandler, Mockito.times(2)).handle(argumentCaptor.capture());
            rawEvent = argumentCaptor.getAllValues().get(1);
            Assert.assertTrue(rawEvent instanceof DAGAppMasterEventUserServiceFatalError);
            event = (DAGAppMasterEventUserServiceFatalError) rawEvent;
            Assert.assertEquals(DAGAppMasterEventType.TASK_COMMUNICATOR_SERVICE_FATAL_ERROR, event.getType());
            Assert.assertTrue(event.getError().getMessage().contains("TestException_registerRunningContainer"));
            Assert.assertTrue(event.getDiagnosticInfo().contains("registering running Container"));
            Assert.assertTrue(event.getDiagnosticInfo().contains(expectedId));
        } finally {
            taskCommunicatorManager.stop();
        }
    }

    /**
     * Communicator that reports errors through its context instead of throwing.
     *
     * <ul>
     *   <li>{@code registerRunningContainer} reports SERVICE_UNAVAILABLE together with a
     *       {@link DagInfo}.</li>
     *   <li>{@code dagComplete} reports INCONSISTENT_STATE with no DagInfo.</li>
     * </ul>
     *
     * <p>All other methods are no-ops.
     */
    public static class TaskCommForFailureTest extends TaskCommunicator {
        public TaskCommForFailureTest(TaskCommunicatorContext taskCommunicatorContext) {
            super(taskCommunicatorContext);
        }

        @Override
        public void registerRunningContainer(ContainerId containerId, String hostname, int port) throws ServicePluginException {
            getContext().reportError(ServicePluginErrorDefaults.SERVICE_UNAVAILABLE, "ReportError",
                    new DagInfoImplForTest(DAG_INDEX, DAG_NAME));
        }

        @Override
        public void registerContainerEnd(ContainerId containerId, ContainerEndReason endReason, @Nullable String diagnostics) throws ServicePluginException {
        }

        @Override
        public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec, Map<String, LocalResource> additionalResources, Credentials credentials, boolean credentialsChanged, int priority) throws ServicePluginException {
        }

        @Override
        public void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID, TaskAttemptEndReason endReason, @Nullable String diagnostics) throws ServicePluginException {
        }

        @Override
        public InetSocketAddress getAddress() throws ServicePluginException {
            return null;
        }

        @Override
        public void onVertexStateUpdated(VertexStateUpdate stateUpdate) throws ServicePluginException {
        }

        @Override
        public void dagComplete(int dagIdentifier) throws ServicePluginException {
            getContext().reportError(ServicePluginErrorDefaults.INCONSISTENT_STATE, "ReportedFatalError", null);
        }

        @Override
        public Object getMetaInfo() throws ServicePluginException {
            return null;
        }
    }

    /** Communicator whose every plugin method is a no-op; used as the "custom" plugin. */
    public static class FakeTaskComm extends TaskCommunicator {
        public FakeTaskComm(TaskCommunicatorContext taskCommunicatorContext) {
            super(taskCommunicatorContext);
        }

        @Override
        public void registerRunningContainer(ContainerId containerId, String hostname, int port) {
        }

        @Override
        public void registerContainerEnd(ContainerId containerId, ContainerEndReason endReason, String diagnostics) {
        }

        @Override
        public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec, Map<String, LocalResource> additionalResources, Credentials credentials, boolean credentialsChanged, int priority) {
        }

        @Override
        public void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID, TaskAttemptEndReason endReason, String diagnostics) {
        }

        @Override
        public InetSocketAddress getAddress() {
            return null;
        }

        @Override
        public void onVertexStateUpdated(VertexStateUpdate stateUpdate) {
        }

        @Override
        public void dagComplete(int dagIdentifier) {
        }

        @Override
        public Object getMetaInfo() {
            return null;
        }
    }

    /**
     * Manager subclass that records which communicators get created.
     *
     * <p>It captures each communicator's context and name. It substitutes shared mocks for
     * the YARN and uber communicators, and wraps custom communicators in Mockito spies so
     * tests can verify calls on them.
     *
     * <p>State is kept in static fields. Presumably this is because the {@code create*} hooks
     * run from the superclass constructor, before instance fields would be initialized
     * (confirm against {@code TaskCommunicatorManager}). As a result, {@link #reset()} must
     * run between tests.
     */
    static class TaskCommManagerForMultipleCommTest extends TaskCommunicatorManager {
        private static final AtomicInteger numTaskComms = new AtomicInteger(0);
        private static final Set<Integer> taskCommIndices = new HashSet<Integer>();
        private static final TaskCommunicator yarnTaskComm = Mockito.mock(TaskCommunicator.class);
        private static final TaskCommunicator uberTaskComm = Mockito.mock(TaskCommunicator.class);
        private static final AtomicBoolean yarnTaskCommCreated = new AtomicBoolean(false);
        private static final AtomicBoolean uberTaskCommCreated = new AtomicBoolean(false);
        private static final List<TaskCommunicatorContext> taskCommContexts = new LinkedList<TaskCommunicatorContext>();
        private static final List<String> taskCommNames = new LinkedList<String>();
        private static final List<TaskCommunicator> testTaskComms = new LinkedList<TaskCommunicator>();

        /** Clears all recorded state, including invocations recorded on the shared mocks. */
        public static void reset() {
            numTaskComms.set(0);
            taskCommIndices.clear();
            yarnTaskCommCreated.set(false);
            uberTaskCommCreated.set(false);
            taskCommContexts.clear();
            taskCommNames.clear();
            testTaskComms.clear();
            // The YARN/uber mocks are static and shared by every test. Without this,
            // invocations from an earlier test (initialize/start/shutdown/...) would leak
            // into verify(...) counts of a later one, making results order-dependent.
            Mockito.reset(yarnTaskComm, uberTaskComm);
        }

        public TaskCommManagerForMultipleCommTest(AppContext context, TaskHeartbeatHandler thh, ContainerHeartbeatHandler chh, List<NamedEntityDescriptor> taskCommunicatorDescriptors) throws TezException {
            super(context, thh, chh, taskCommunicatorDescriptors);
        }

        @Override
        TaskCommunicator createTaskCommunicator(NamedEntityDescriptor taskCommDescriptor, int taskCommIndex) throws TezException {
            numTaskComms.incrementAndGet();
            boolean added = taskCommIndices.add(taskCommIndex);
            Assert.assertTrue("Cannot add multiple taskComms with the same index", added);
            taskCommNames.add(taskCommDescriptor.getEntityName());
            return super.createTaskCommunicator(taskCommDescriptor, taskCommIndex);
        }

        @Override
        TaskCommunicator createDefaultTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext) {
            taskCommContexts.add(taskCommunicatorContext);
            yarnTaskCommCreated.set(true);
            testTaskComms.add(yarnTaskComm);
            return yarnTaskComm;
        }

        @Override
        TaskCommunicator createUberTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext) {
            taskCommContexts.add(taskCommunicatorContext);
            uberTaskCommCreated.set(true);
            testTaskComms.add(uberTaskComm);
            return uberTaskComm;
        }

        @Override
        TaskCommunicator createCustomTaskCommunicator(TaskCommunicatorContext taskCommunicatorContext, NamedEntityDescriptor taskCommDescriptor) throws TezException {
            taskCommContexts.add(taskCommunicatorContext);
            TaskCommunicator spyComm = Mockito.spy(super.createCustomTaskCommunicator(taskCommunicatorContext, taskCommDescriptor));
            testTaskComms.add(spyComm);
            return spyComm;
        }

        public static int getNumTaskComms() {
            return numTaskComms.get();
        }

        public static boolean getYarnTaskCommCreated() {
            return yarnTaskCommCreated.get();
        }

        public static boolean getUberTaskCommCreated() {
            return uberTaskCommCreated.get();
        }

        public static TaskCommunicatorContext getTaskCommContext(int taskCommIndex) {
            return taskCommContexts.get(taskCommIndex);
        }

        public static String getTaskCommName(int taskCommIndex) {
            return taskCommNames.get(taskCommIndex);
        }

        public static TaskCommunicator getTestTaskComm(int taskCommIndex) {
            return testTaskComms.get(taskCommIndex);
        }
    }

    /**
     * Default answer for a {@link TaskCommunicator} mock.
     *
     * <p>Any method declared on {@code TaskCommunicator} throws
     * {@code RuntimeException("TestException_<methodName>")}. The exceptions are
     * {@code getContext} and the lifecycle methods {@code initialize}/{@code start}/{@code shutdown}.
     * Those, and methods declared elsewhere, delegate to the real implementation.
     */
    private static class ExceptionAnswer implements Answer<Object> {
        @Override
        public Object answer(InvocationOnMock invocation) throws Throwable {
            Method method = invocation.getMethod();
            String methodName = method.getName();
            boolean passThrough = methodName.equals("getContext")
                    || methodName.equals("initialize")
                    || methodName.equals("start")
                    || methodName.equals("shutdown");
            if (method.getDeclaringClass().equals(TaskCommunicator.class) && !passThrough) {
                throw new RuntimeException("TestException_" + methodName);
            }
            return invocation.callRealMethod();
        }
    }
}

