/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.dag.app.dag.impl;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.yarn.event.Event;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexManagerPlugin;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.dag.api.VertexManagerPluginDescriptor;
import org.apache.tez.dag.app.AppContext;
import org.apache.tez.dag.app.dag.StateChangeNotifier;
import org.apache.tez.dag.app.dag.Vertex;
import org.apache.tez.dag.app.dag.event.CallableEvent;
import org.apache.tez.dag.app.dag.event.VertexEventInputDataInformation;
import org.apache.tez.dag.app.dag.event.VertexEventRouteEvent;
import org.apache.tez.dag.app.dag.impl.CallableEventDispatcher;
import org.apache.tez.dag.app.dag.impl.RootInputVertexManager;
import org.apache.tez.dag.app.dag.impl.VertexManager;
import org.apache.tez.runtime.api.TaskAttemptIdentifier;
import org.apache.tez.runtime.api.events.CustomProcessorEvent;
import org.apache.tez.runtime.api.events.InputDataInformationEvent;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.apache.tez.runtime.api.impl.GroupInputSpec;
import org.apache.tez.runtime.api.impl.TezEvent;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.mockito.verification.VerificationMode;

public class TestVertexManager {
    AppContext mockAppContext;
    ListeningExecutorService execService;
    Vertex mockVertex;
    EventHandler mockHandler;
    ArgumentCaptor<VertexEventInputDataInformation> requestCaptor;

    /**
     * Wires up the mocks shared by every test.
     *
     * <p>The executor mock runs each submitted {@link CallableEvent} inline through a
     * {@link CallableEventDispatcher} and returns a mock future, so work submitted to the
     * executor has finished by the time a test inspects {@link #mockHandler}. The app context
     * returns {@link #mockHandler} as its event handler, and any DAG vertex looked up by name
     * reports a single task.
     */
    @Before
    @SuppressWarnings("unchecked")
    public void setup() {
        mockAppContext = Mockito.mock(AppContext.class, Mockito.RETURNS_DEEP_STUBS);
        execService = Mockito.mock(ListeningExecutorService.class);
        final ListenableFuture<Void> mockFuture = Mockito.mock(ListenableFuture.class);
        Mockito.doAnswer(new Answer<ListenableFuture<Void>>() {
            @Override
            public ListenableFuture<Void> answer(InvocationOnMock invocation) {
                // Dispatch synchronously instead of on a real thread pool.
                CallableEvent callableEvent = (CallableEvent) invocation.getArguments()[0];
                new CallableEventDispatcher().handle(callableEvent);
                return mockFuture;
            }
        }).when(execService).submit(Mockito.<Callable<Void>>any());
        Mockito.doReturn(execService).when(mockAppContext).getExecService();

        mockVertex = Mockito.mock(Vertex.class, Mockito.RETURNS_DEEP_STUBS);
        Mockito.doReturn("vertex1").when(mockVertex).getName();

        mockHandler = Mockito.mock(EventHandler.class);
        Mockito.when(mockAppContext.getEventHandler()).thenReturn(mockHandler);
        Mockito.when(mockAppContext.getCurrentDAG().getVertex(Mockito.any(String.class)).getTotalTasks())
            .thenReturn(1);

        requestCaptor = ArgumentCaptor.forClass(VertexEventInputDataInformation.class);
    }

    /**
     * Checks that a plugin can read its user payload from the context inside its constructor.
     * The assertion lives in {@link CheckUserPayloadVertexManagerPlugin}'s constructor, so this
     * test relies on the {@link VertexManager} constructor instantiating the plugin.
     */
    @Test(timeout = 5000)
    public void testVertexManagerPluginCtorAccessUserPayload() throws IOException, TezException {
        byte[] randomUserPayload = {1, 2, 3};
        UserPayload userPayload = UserPayload.create(ByteBuffer.wrap(randomUserPayload));
        VertexManagerPluginDescriptor pluginDesc =
            VertexManagerPluginDescriptor.create(CheckUserPayloadVertexManagerPlugin.class.getName());
        pluginDesc.setUserPayload(userPayload);
        createVertexManager(pluginDesc);
    }

    /**
     * Feeds two successive root-input initializations to a {@link RootInputVertexManager}.
     * Each call must emit exactly one input-data-information request carrying that call's event.
     */
    @Test(timeout = 5000)
    public void testOnRootVertexInitialized() throws Exception {
        VertexManager vm = createVertexManager(
            RootInputVertexManager.createConfigBuilder(new Configuration()).build());
        vm.initialize();

        InputDataInformationEvent diEvent1 = InputDataInformationEvent.createWithSerializedPayload(0, null);
        vm.onRootVertexInitialized("input1", Mockito.mock(InputDescriptor.class), singletonEventList(diEvent1));
        Mockito.verify(mockHandler, Mockito.times(1)).handle(requestCaptor.capture());
        List<TezEvent> tezEvents1 = requestCaptor.getValue().getEvents();
        Assert.assertEquals(1, tezEvents1.size());
        Assert.assertEquals(diEvent1, tezEvents1.get(0).getEvent());

        InputDataInformationEvent diEvent2 = InputDataInformationEvent.createWithSerializedPayload(0, null);
        vm.onRootVertexInitialized("input1", Mockito.mock(InputDescriptor.class), singletonEventList(diEvent2));
        Mockito.verify(mockHandler, Mockito.times(2)).handle(requestCaptor.capture());
        List<TezEvent> tezEvents2 = requestCaptor.getValue().getEvents();
        Assert.assertEquals(1, tezEvents2.size());
        Assert.assertEquals(diEvent2, tezEvents2.get(0).getEvent());
    }

    /**
     * Uses {@link CustomVertexManager}, which buffers root-input events until "input2" arrives.
     * The first callback must yield an empty request. The second must yield events addressed to
     * both inputs.
     */
    @Test(timeout = 5000)
    public void testOnRootVertexInitialized2() throws Exception {
        VertexManager vm = createVertexManager(
            VertexManagerPluginDescriptor.create(CustomVertexManager.class.getName()));
        vm.initialize();

        InputDataInformationEvent diEvent1 = InputDataInformationEvent.createWithSerializedPayload(0, null);
        vm.onRootVertexInitialized("input1", Mockito.mock(InputDescriptor.class), singletonEventList(diEvent1));
        Mockito.verify(mockHandler, Mockito.times(1)).handle(requestCaptor.capture());
        List<TezEvent> tezEventsAfterInput1 = requestCaptor.getValue().getEvents();
        Assert.assertEquals(0, tezEventsAfterInput1.size());

        InputDataInformationEvent diEvent2 = InputDataInformationEvent.createWithSerializedPayload(0, null);
        vm.onRootVertexInitialized("input2", Mockito.mock(InputDescriptor.class), singletonEventList(diEvent2));
        Mockito.verify(mockHandler, Mockito.times(2)).handle(requestCaptor.capture());
        List<TezEvent> tezEventsAfterInput2 = requestCaptor.getValue().getEvents();
        Assert.assertEquals(2, tezEventsAfterInput2.size());

        HashSet<String> edgeVertexSet = new HashSet<>();
        for (TezEvent tezEvent : tezEventsAfterInput2) {
            edgeVertexSet.add(tezEvent.getDestinationInfo().getEdgeVertexName());
        }
        Assert.assertEquals(Sets.newHashSet("input1", "input2"), edgeVertexSet);
    }

    /**
     * Checks that the plugin context reports input vertex groups from the managed vertex's
     * group input specs. The result is empty before any spec is stubbed.
     */
    @Test(timeout = 5000)
    public void testVMPluginCtxGetInputVertexGroup() throws Exception {
        VertexManager vm = createVertexManager(
            VertexManagerPluginDescriptor.create(CustomVertexManager.class.getName()));
        Assert.assertTrue(vm.pluginContext.getInputVertexGroups().isEmpty());

        String group = "group";
        String v1 = "v1";
        String v2 = "v2";
        Mockito.when(mockVertex.getGroupInputSpecList())
            .thenReturn(Arrays.asList(new GroupInputSpec(group, Arrays.asList(v1, v2), null)));

        Map<String, List<String>> groups = vm.pluginContext.getInputVertexGroups();
        Assert.assertEquals(1, groups.size());
        Assert.assertTrue(groups.containsKey(group));
        Assert.assertEquals(2, groups.get(group).size());
        Assert.assertTrue(groups.get(group).contains(v1));
        Assert.assertTrue(groups.get(group).contains(v2));
    }

    /**
     * Exercises {@code sendEventToProcessor}. It checks that:
     * <ul>
     *   <li>out-of-range task ids are rejected;</li>
     *   <li>null or empty event collections produce no routed event;</li>
     *   <li>a non-empty collection produces exactly one routed event.</li>
     * </ul>
     * The routed event's payload must be readable intact on every call.
     */
    @Test(timeout = 5000)
    public void testSendCustomProcessorEvent() throws Exception {
        VertexManager vm = createVertexManager(
            VertexManagerPluginDescriptor.create(CustomVertexManager.class.getName()));
        ArgumentCaptor<VertexEventRouteEvent> routeEventCaptor =
            ArgumentCaptor.forClass(VertexEventRouteEvent.class);
        Mockito.when(mockVertex.getTotalTasks()).thenReturn(2);
        List<CustomProcessorEvent> events = new ArrayList<>();

        assertInvalidTaskIdRejected(vm, events, -1);
        assertInvalidTaskIdRejected(vm, events, 10);

        vm.pluginContext.sendEventToProcessor(null, 0);
        Mockito.verify(mockHandler, Mockito.never()).handle(routeEventCaptor.capture());
        vm.pluginContext.sendEventToProcessor(events, 1);
        Mockito.verify(mockHandler, Mockito.never()).handle(routeEventCaptor.capture());

        byte[] payload = {1, 2, 3};
        events.add(CustomProcessorEvent.create(ByteBuffer.wrap(payload)));
        vm.pluginContext.sendEventToProcessor(events, 1);
        Mockito.verify(mockHandler, Mockito.times(1)).handle(routeEventCaptor.capture());

        CustomProcessorEvent cpe =
            (CustomProcessorEvent) routeEventCaptor.getValue().getEvents().get(0).getEvent();
        // getPayload() is read twice. The second pass only succeeds if each call returns a buffer
        // whose position is independent of the one consumed by the previous pass.
        for (int pass = 0; pass < 2; ++pass) {
            ByteBuffer payloadBuffer = cpe.getPayload();
            Assert.assertEquals(payload.length, payloadBuffer.remaining());
            for (byte expected : payload) {
                Assert.assertEquals(expected, payloadBuffer.get());
            }
        }
    }

    /** Builds a VertexManager for {@code pluginDesc} around this fixture's shared mocks. */
    private VertexManager createVertexManager(VertexManagerPluginDescriptor pluginDesc)
            throws IOException, TezException {
        return new VertexManager(pluginDesc, UserGroupInformation.getCurrentUser(), mockVertex,
            mockAppContext, Mockito.mock(StateChangeNotifier.class));
    }

    /** Wraps one runtime event in a mutable list of the type onRootVertexInitialized accepts. */
    private static List<org.apache.tez.runtime.api.Event> singletonEventList(
            org.apache.tez.runtime.api.Event event) {
        List<org.apache.tez.runtime.api.Event> events = new LinkedList<>();
        events.add(event);
        return events;
    }

    /** Asserts that sending to {@code taskId} fails with an "Invalid taskId" IllegalArgumentException. */
    private static void assertInvalidTaskIdRejected(VertexManager vm, List<CustomProcessorEvent> events,
            int taskId) {
        try {
            vm.pluginContext.sendEventToProcessor(events, taskId);
            Assert.fail("Should fail for invalid task id");
        } catch (IllegalArgumentException exception) {
            Assert.assertTrue(exception.getMessage().contains("Invalid taskId"));
        }
    }

    /**
     * Test plugin that caches root-input events per input name. When "input2" is initialized,
     * it forwards every cached input's events back through the context as root input events.
     */
    public static class CustomVertexManager extends VertexManagerPlugin {
        private final Map<String, List<org.apache.tez.runtime.api.Event>> cachedEventMap = new HashMap<>();

        public CustomVertexManager(VertexManagerPluginContext context) {
            super(context);
        }

        @Override
        public void initialize() {
        }

        @Override
        public void onVertexStarted(List<TaskAttemptIdentifier> completions) {
        }

        @Override
        public void onSourceTaskCompleted(TaskAttemptIdentifier attempt) {
        }

        @Override
        public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
        }

        @Override
        public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor,
                List<org.apache.tez.runtime.api.Event> events) {
            cachedEventMap.put(inputName, events);
            if (!inputName.equals("input2")) {
                return;
            }
            for (Map.Entry<String, List<org.apache.tez.runtime.api.Event>> entry : cachedEventMap.entrySet()) {
                List<InputDataInformationEvent> riEvents = Lists.newLinkedList();
                // Forward each input's own cached events. The previous version re-sent the
                // current call's events for every input.
                for (org.apache.tez.runtime.api.Event event : entry.getValue()) {
                    riEvents.add((InputDataInformationEvent) event);
                }
                getContext().addRootInputEvents(entry.getKey(), riEvents);
            }
        }
    }

    /** Test plugin whose constructor asserts that the context exposes a non-null user payload. */
    public static class CheckUserPayloadVertexManagerPlugin extends VertexManagerPlugin {
        public CheckUserPayloadVertexManagerPlugin(VertexManagerPluginContext context) {
            super(context);
            Assert.assertNotNull(context.getUserPayload());
        }

        @Override
        public void initialize() throws Exception {
        }

        @Override
        public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) throws Exception {
        }

        @Override
        public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor,
                List<org.apache.tez.runtime.api.Event> events) throws Exception {
        }
    }
}

