package org.apache.tez.runtime.library.cartesianproduct;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.TezReflectionException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexLocationHint;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.dag.records.TaskAttemptIdentifierImpl;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.TaskAttemptIdentifier;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

/* loaded from: input_file:org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.class */
public class TestCartesianProductVertexManagerPartitioned {

    @Captor
    private ArgumentCaptor<Map<String, EdgeProperty>> edgePropertiesCaptor;

    @Captor
    private ArgumentCaptor<List<VertexManagerPluginContext.ScheduleTaskRequest>> scheduleTaskRequestCaptor;
    private CartesianProductVertexManagerPartitioned vertexManager;
    private VertexManagerPluginContext context;
    private List<TaskAttemptIdentifier> allCompletions;

    /* loaded from: input_file:org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned$TestFilter.class */
    public static class TestFilter extends CartesianProductFilter {
        public TestFilter(UserPayload userPayload) {
            super(userPayload);
        }

        public boolean isValidCombination(Map<String, Integer> map) {
            return map.get("v0").intValue() > map.get("v1").intValue();
        }
    }

    @Before
    public void setup() throws TezReflectionException {
        CartesianProductUserPayload.CartesianProductConfigProto.Builder newBuilder = CartesianProductUserPayload.CartesianProductConfigProto.newBuilder();
        newBuilder.setIsPartitioned(true).addSources("v0").addSources("v1").addNumPartitions(2).addNumPartitions(2);
        setupWithConfig(newBuilder.build());
    }

    private void setupWithConfig(CartesianProductUserPayload.CartesianProductConfigProto cartesianProductConfigProto) throws TezReflectionException {
        MockitoAnnotations.openMocks(this);
        this.context = (VertexManagerPluginContext) Mockito.mock(VertexManagerPluginContext.class);
        Mockito.when(this.context.getVertexName()).thenReturn("cp");
        Mockito.when(Integer.valueOf(this.context.getVertexNumTasks("cp"))).thenReturn(-1);
        this.vertexManager = new CartesianProductVertexManagerPartitioned(this.context);
        HashMap hashMap = new HashMap();
        hashMap.put("v0", EdgeProperty.create(EdgeManagerPluginDescriptor.create(CartesianProductEdgeManager.class.getName()), (EdgeProperty.DataSourceType) null, (EdgeProperty.SchedulingType) null, (OutputDescriptor) null, (InputDescriptor) null));
        hashMap.put("v1", EdgeProperty.create(EdgeManagerPluginDescriptor.create(CartesianProductEdgeManager.class.getName()), (EdgeProperty.DataSourceType) null, (EdgeProperty.SchedulingType) null, (OutputDescriptor) null, (InputDescriptor) null));
        hashMap.put("v2", EdgeProperty.create(EdgeProperty.DataMovementType.BROADCAST, (EdgeProperty.DataSourceType) null, (EdgeProperty.SchedulingType) null, (OutputDescriptor) null, (InputDescriptor) null));
        Mockito.when(this.context.getInputVertexEdgeProperties()).thenReturn(hashMap);
        Mockito.when(Integer.valueOf(this.context.getVertexNumTasks((String) Mockito.eq("v0")))).thenReturn(4);
        Mockito.when(Integer.valueOf(this.context.getVertexNumTasks((String) Mockito.eq("v1")))).thenReturn(4);
        Mockito.when(Integer.valueOf(this.context.getVertexNumTasks((String) Mockito.eq("v2")))).thenReturn(4);
        this.vertexManager.initialize(cartesianProductConfigProto);
        this.allCompletions = new ArrayList();
        for (int i = 0; i < 3; i++) {
            for (int i2 = 0; i2 < 4; i2++) {
                this.allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v" + i, TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(TezDAGID.getInstance("0", 0, 0), i), i2), 0)));
            }
        }
    }

    private void testReconfigureVertexHelper(CartesianProductUserPayload.CartesianProductConfigProto cartesianProductConfigProto, int i) throws Exception {
        setupWithConfig(cartesianProductConfigProto);
        ArgumentCaptor forClass = ArgumentCaptor.forClass(Integer.class);
        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
        ((VertexManagerPluginContext) Mockito.verify(this.context, Mockito.times(1))).reconfigureVertex(((Integer) forClass.capture()).intValue(), (VertexLocationHint) Mockito.isNull(), (Map) this.edgePropertiesCaptor.capture());
        Assert.assertEquals(((Integer) forClass.getValue()).intValue(), i);
        Assert.assertNull(this.edgePropertiesCaptor.getValue());
    }

    @Test(timeout = 5000)
    public void testReconfigureVertex() throws Exception {
        CartesianProductUserPayload.CartesianProductConfigProto.Builder newBuilder = CartesianProductUserPayload.CartesianProductConfigProto.newBuilder();
        newBuilder.setIsPartitioned(true).addSources("v0").addSources("v1").addNumPartitions(5).addNumPartitions(5).setFilterClassName(TestFilter.class.getName());
        testReconfigureVertexHelper(newBuilder.build(), 10);
        newBuilder.clearFilterClassName();
        testReconfigureVertexHelper(newBuilder.build(), 25);
    }

    @Test(timeout = 5000)
    public void testScheduling() throws Exception {
        this.vertexManager.onVertexStarted((List) null);
        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
        this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(0));
        this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(1));
        ((VertexManagerPluginContext) Mockito.verify(this.context, Mockito.never())).scheduleTasks((List) Mockito.any());
        this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(2));
        ((VertexManagerPluginContext) Mockito.verify(this.context, Mockito.never())).scheduleTasks((List) Mockito.any());
        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
        ((VertexManagerPluginContext) Mockito.verify(this.context, Mockito.times(1))).scheduleTasks((List) this.scheduleTaskRequestCaptor.capture());
        List list = (List) this.scheduleTaskRequestCaptor.getValue();
        Assert.assertEquals(1L, list.size());
        Assert.assertEquals(0L, ((VertexManagerPluginContext.ScheduleTaskRequest) list.get(0)).getTaskIndex());
        this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(8));
        ((VertexManagerPluginContext) Mockito.verify(this.context, Mockito.times(1))).scheduleTasks((List) this.scheduleTaskRequestCaptor.capture());
        for (int i = 3; i < 6; i++) {
            this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(i));
            ((VertexManagerPluginContext) Mockito.verify(this.context, Mockito.times(i - 1))).scheduleTasks((List) this.scheduleTaskRequestCaptor.capture());
            List list2 = (List) this.scheduleTaskRequestCaptor.getValue();
            Assert.assertEquals(1L, list2.size());
            Assert.assertEquals(i - 2, ((VertexManagerPluginContext.ScheduleTaskRequest) list2.get(0)).getTaskIndex());
        }
        for (int i2 = 6; i2 < 8; i2++) {
            this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(i2));
            ((VertexManagerPluginContext) Mockito.verify(this.context, Mockito.times(4))).scheduleTasks((List) Mockito.any());
        }
    }

    @Test(timeout = 5000)
    public void testOnVertexStartWithBroadcastRunning() throws Exception {
        testOnVertexStartHelper(true);
    }

    @Test(timeout = 5000)
    public void testOnVertexStartWithoutBroadcastRunning() throws Exception {
        testOnVertexStartHelper(false);
    }

    private void testOnVertexStartHelper(boolean z) throws Exception {
        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
        if (z) {
            this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
        }
        ArrayList arrayList = new ArrayList();
        arrayList.add(this.allCompletions.get(0));
        arrayList.add(this.allCompletions.get(1));
        arrayList.add(this.allCompletions.get(4));
        arrayList.add(this.allCompletions.get(8));
        this.vertexManager.onVertexStarted(arrayList);
        if (!z) {
            ((VertexManagerPluginContext) Mockito.verify(this.context, Mockito.never())).scheduleTasks((List) Mockito.any());
            this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
        }
        ((VertexManagerPluginContext) Mockito.verify(this.context, Mockito.times(1))).scheduleTasks((List) this.scheduleTaskRequestCaptor.capture());
        List list = (List) this.scheduleTaskRequestCaptor.getValue();
        Assert.assertEquals(1L, list.size());
        Assert.assertEquals(0L, ((VertexManagerPluginContext.ScheduleTaskRequest) list.get(0)).getTaskIndex());
    }
}
