package org.apache.drill.exec.planner.rm;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.drill.PlanTestBase;
import org.apache.drill.exec.ops.QueryContext;
import org.apache.drill.exec.planner.cost.NodeResource;
import org.apache.drill.exec.planner.fragment.Fragment;
import org.apache.drill.exec.planner.fragment.PlanningSet;
import org.apache.drill.exec.planner.fragment.QueueQueryParallelizer;
import org.apache.drill.exec.planner.fragment.SimpleParallelizer;
import org.apache.drill.exec.planner.fragment.Wrapper;
import org.apache.drill.exec.pop.PopUnitTestBase;
import org.apache.drill.exec.proto.CoordinationProtos;
import org.apache.drill.exec.proto.UserBitShared;
import org.apache.drill.exec.proto.UserProtos;
import org.apache.drill.exec.rpc.user.UserSession;
import org.apache.drill.exec.server.DrillbitContext;
import org.apache.drill.exec.work.foreman.rm.EmbeddedQueryQueue;
import org.apache.drill.shaded.guava.com.google.common.collect.Iterables;
import org.apache.drill.test.ClientFixture;
import org.apache.drill.test.ClusterFixture;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/drill/exec/planner/rm/TestMemoryCalculator.class */
/**
 * Tests the per-node memory figures computed by {@code QueueQueryParallelizer.adjustMemory}.
 *
 * <p>Each test plans a query on a real test cluster, then replaces the resulting fragment
 * wrappers with Mockito mocks. Every mock reports the same fixed endpoint list and shared
 * {@link NodeResource} map. The test then checks the memory that {@code adjustMemory}
 * records for each endpoint.
 */
public class TestMemoryCalculator extends PlanTestBase {

    /** Slice target used by the plan helpers unless a caller supplies one. */
    private static final long DEFAULT_SLICE_TARGET = 100000;

    /** Output batch size (16 MiB) used by the plan helpers unless a caller supplies one. */
    private static final long DEFAULT_BATCH_SIZE = 16 * 1024 * 1024;

    /**
     * Output batch size used by {@link #preparePlanningSet} when generating plans.
     * NOTE(review): the expected memory figures asserted in the tests were presumably
     * computed with this value; confirm before changing it.
     */
    private static final long PLANNING_OUTPUT_BATCH_SIZE = 10L;

    // NOTE(review): this reads bits[0] from the base class during static initialization.
    // It assumes the inherited test cluster is already running when this class is
    // initialized; confirm.
    private static final UserSession session = UserSession.Builder.newBuilder()
        .withCredentials(UserBitShared.UserCredentials.newBuilder()
            .setUserName("foo")
            .build())
        .withUserProperties(UserProtos.UserProperties.getDefaultInstance())
        .withOptionManager(bits[0].getContext().getOptionManager())
        .build();

    private static final CoordinationProtos.DrillbitEndpoint N1_EP1 = newDrillbitEndpoint("node1", 30010);
    private static final CoordinationProtos.DrillbitEndpoint N1_EP2 = newDrillbitEndpoint("node2", 30011);
    private static final CoordinationProtos.DrillbitEndpoint N1_EP3 = newDrillbitEndpoint("node3", 30012);
    private static final CoordinationProtos.DrillbitEndpoint N1_EP4 = newDrillbitEndpoint("node4", 30013);

    /** Fake endpoints handed out round-robin by {@link #getEndpoints}. */
    private static final CoordinationProtos.DrillbitEndpoint[] nodeList = {N1_EP1, N1_EP2, N1_EP3, N1_EP4};

    private static final DrillbitContext drillbitContext = getDrillbitContext();
    private static final QueryContext queryContext =
        new QueryContext(session, drillbitContext, UserBitShared.QueryId.getDefaultInstance());

    /**
     * Builds a fake endpoint that has only an address and a control port set.
     *
     * @param address     host name of the endpoint
     * @param controlPort control port of the endpoint
     * @return the built endpoint
     */
    private static CoordinationProtos.DrillbitEndpoint newDrillbitEndpoint(String address, int controlPort) {
        return CoordinationProtos.DrillbitEndpoint.newBuilder()
            .setAddress(address)
            .setControlPort(controlPort)
            .build();
    }

    /** Releases the shared query context created for this test class. */
    @AfterClass
    public static void close() throws Exception {
        queryContext.close();
    }

    /**
     * Recursively replaces {@code wrapper} and all of its fragment dependencies with mocks.
     *
     * <p>Each mock reports the original fragment node, the given endpoints as its assignment,
     * a width equal to the endpoint count, and the shared resource map.
     *
     * @param wrapper         real wrapper to mirror
     * @param resources       resource map every mock reports
     * @param endpoints       endpoints every mock reports as assigned
     * @param mocksByFragment output map; receives one entry per mirrored fragment
     * @return the mock standing in for {@code wrapper}
     */
    private Wrapper mockWrapper(Wrapper wrapper,
                                Map<CoordinationProtos.DrillbitEndpoint, NodeResource> resources,
                                List<CoordinationProtos.DrillbitEndpoint> endpoints,
                                Map<Fragment, Wrapper> mocksByFragment) {
        Wrapper mock = Mockito.mock(Wrapper.class);
        mocksByFragment.put(wrapper.getNode(), mock);

        // Build the child mocks before stubbing this one, so that no stubbing is left
        // unfinished while another mock is being created.
        List<Wrapper> mockedDependencies = new ArrayList<>();
        for (Wrapper dependency : wrapper.getFragmentDependencies()) {
            mockedDependencies.add(mockWrapper(dependency, resources, endpoints, mocksByFragment));
        }

        Mockito.when(mock.getNode()).thenReturn(wrapper.getNode());
        Mockito.when(mock.getAssignedEndpoints()).thenReturn(endpoints);
        Mockito.when(mock.getResourceMap()).thenReturn(resources);
        Mockito.when(mock.getWidth()).thenReturn(endpoints.size());
        Mockito.when(mock.getFragmentDependencies()).thenReturn(mockedDependencies);
        Mockito.when(mock.isEndpointsAssignmentDone()).thenReturn(true);
        return mock;
    }

    /**
     * Wraps {@code planningSet} in a mock whose root wrapper and per-fragment lookups return
     * the mocks produced by {@link #mockWrapper}.
     *
     * @param planningSet real planning set to mirror
     * @param resources   resource map every mocked wrapper reports
     * @param endpoints   endpoints every mocked wrapper reports as assigned
     * @return the mocked planning set
     */
    private PlanningSet mockPlanningSet(PlanningSet planningSet,
                                        Map<CoordinationProtos.DrillbitEndpoint, NodeResource> resources,
                                        List<CoordinationProtos.DrillbitEndpoint> endpoints) {
        Map<Fragment, Wrapper> mocksByFragment = new HashMap<>();
        Wrapper rootMock = mockWrapper(planningSet.getRootWrapper(), resources, endpoints, mocksByFragment);

        PlanningSet mockSet = Mockito.mock(PlanningSet.class);
        Mockito.when(mockSet.getRootWrapper()).thenReturn(rootMock);
        Mockito.when(mockSet.get(ArgumentMatchers.any(Fragment.class)))
            .thenAnswer(invocation -> mocksByFragment.get(invocation.getArgument(0)));
        return mockSet;
    }

    /** Returns the JSON plan for {@code query}, using the default batch size and slice target. */
    private String getPlanForQuery(String query) throws Exception {
        return getPlanForQuery(query, DEFAULT_BATCH_SIZE);
    }

    /** Returns the JSON plan for {@code query}, using the default slice target. */
    private String getPlanForQuery(String query, long outputBatchSize) throws Exception {
        return getPlanForQuery(query, outputBatchSize, DEFAULT_SLICE_TARGET);
    }

    /**
     * Starts a throw-away cluster with the given option defaults and returns the JSON
     * {@code EXPLAIN} output for {@code query}.
     *
     * @param query           SQL to plan
     * @param outputBatchSize value for {@code drill.exec.memory.operator.output_batch_size}
     * @param sliceTarget     value for {@code planner.slice_target}
     * @return the plan as JSON
     */
    private String getPlanForQuery(String query, long outputBatchSize, long sliceTarget) throws Exception {
        // try-with-resources closes the client before the cluster. If a close() call fails
        // while another exception is propagating, that failure is attached as a suppressed
        // exception instead of masking the original one.
        try (ClusterFixture cluster = ClusterFixture.builder(dirTestWatcher)
                 .setOptionDefault("drill.exec.memory.operator.output_batch_size", outputBatchSize)
                 .setOptionDefault("planner.slice_target", sliceTarget)
                 .build();
             ClientFixture client = cluster.clientFixture()) {
            return client.queryBuilder().sql(query).explainJson();
        }
    }

    /**
     * Takes {@code totalMinorFragments} endpoints round-robin from {@link #nodeList}, skipping
     * any endpoint contained in {@code notIn}.
     *
     * <p>Skipped endpoints still consume a slot, so the result can be shorter than
     * {@code totalMinorFragments}. When more than {@code nodeList.length} endpoints are
     * requested, the result contains duplicates.
     *
     * @param totalMinorFragments number of round-robin slots to fill; zero or negative yields an empty list
     * @param notIn               endpoints to exclude
     * @return the selected endpoints, in round-robin order
     */
    private List<CoordinationProtos.DrillbitEndpoint> getEndpoints(int totalMinorFragments,
                                                                   Set<CoordinationProtos.DrillbitEndpoint> notIn) {
        List<CoordinationProtos.DrillbitEndpoint> endpoints = new ArrayList<>();
        Iterator<CoordinationProtos.DrillbitEndpoint> drillbits = Iterables.cycle(nodeList).iterator();
        for (int slot = 0; slot < totalMinorFragments; slot++) {
            CoordinationProtos.DrillbitEndpoint candidate = drillbits.next();
            if (!notIn.contains(candidate)) {
                endpoints.add(candidate);
            }
        }
        return endpoints;
    }

    /** Collects the given wrappers into a new mutable set. */
    private Set<Wrapper> createSet(Wrapper... wrappers) {
        Set<Wrapper> set = new HashSet<>();
        for (Wrapper wrapper : wrappers) {
            set.add(wrapper);
        }
        return set;
    }

    /** Parses a JSON plan string with the context's plan reader and returns its root fragment. */
    private Fragment getRootFragmentFromPlan(DrillbitContext context, String plan) throws Exception {
        return PopUnitTestBase.getRootFragmentFromPlanString(context.getPlanReader(), plan);
    }

    /**
     * Plans {@code query} with {@link #PLANNING_OUTPUT_BATCH_SIZE} and the given slice target.
     * It then builds the fragment tree with {@code parallelizer} and returns a mocked planning
     * set whose wrappers all report {@code endpoints} and {@code resources}.
     */
    private PlanningSet preparePlanningSet(List<CoordinationProtos.DrillbitEndpoint> endpoints,
                                           long sliceTarget,
                                           Map<CoordinationProtos.DrillbitEndpoint, NodeResource> resources,
                                           String query,
                                           SimpleParallelizer parallelizer) throws Exception {
        String plan = getPlanForQuery(query, PLANNING_OUTPUT_BATCH_SIZE, sliceTarget);
        Fragment rootFragment = getRootFragmentFromPlan(drillbitContext, plan);
        return mockPlanningSet(parallelizer.prepareFragmentTree(rootFragment), resources, endpoints);
    }

    /**
     * Runs {@code adjustMemory} for {@code query} on two fake endpoints and asserts that each
     * endpoint ends up with exactly {@code expectedMemory}.
     *
     * @param query          SQL to plan
     * @param sliceTarget    value for {@code planner.slice_target}
     * @param expectedMemory memory value every endpoint's {@link NodeResource} must report
     */
    private void assertMemoryPerNode(String query, long sliceTarget, long expectedMemory) throws Exception {
        List<CoordinationProtos.DrillbitEndpoint> endpoints = getEndpoints(2, new HashSet<>());
        Map<CoordinationProtos.DrillbitEndpoint, NodeResource> resources = endpoints.stream()
            .collect(Collectors.toMap(endpoint -> endpoint, endpoint -> NodeResource.create()));

        QueueQueryParallelizer parallelizer = new QueueQueryParallelizer(false, queryContext);
        PlanningSet planningSet = preparePlanningSet(endpoints, sliceTarget, resources, query, parallelizer);
        parallelizer.adjustMemory(planningSet, createSet(planningSet.getRootWrapper()), endpoints);

        // Guard against a vacuous pass: checking every entry of an empty map proves nothing.
        Assert.assertFalse("no endpoints were assigned", resources.isEmpty());
        for (Map.Entry<CoordinationProtos.DrillbitEndpoint, NodeResource> entry : resources.entrySet()) {
            long actualMemory = entry.getValue().getMemory();
            Assert.assertEquals("memory requirement is different", expectedMemory, actualMemory);
        }
    }

    @Test
    public void TestSingleMajorFragmentWithProjectAndScan() throws Exception {
        assertMemoryPerNode("SELECT * from cp.`tpch/nation.parquet`", DEFAULT_SLICE_TARGET, 30);
    }

    @Test
    public void TestSingleMajorFragmentWithGroupByProjectAndScan() throws Exception {
        assertMemoryPerNode("SELECT dept_id, count(*) from cp.`tpch/lineitem.parquet` group by dept_id",
            DEFAULT_SLICE_TARGET, 529570);
    }

    @Test
    public void TestTwoMajorFragmentWithSortyProjectAndScan() throws Exception {
        // A slice target of 2 makes the planner split the query into more than one major fragment.
        assertMemoryPerNode("SELECT * from cp.`tpch/lineitem.parquet` order by dept_id", 2L, 481490);
    }

    /**
     * Runs a simple query on a cluster with {@code EmbeddedQueryQueue.ENABLED} set.
     * Despite the method name, this enables the embedded query queue, not a
     * ZooKeeper-based one.
     */
    @Test
    public void TestZKBasedQueue() throws Exception {
        // try-with-resources closes each fixture exactly once, client before cluster.
        try (ClusterFixture cluster = ClusterFixture.builder(dirTestWatcher)
                 .configProperty(EmbeddedQueryQueue.ENABLED, true)
                 .build();
             ClientFixture client = cluster.clientFixture()) {
            Assert.assertTrue("query did not succeed",
                client.queryBuilder().sql("select * from cp.`employee.json`").run().succeeded());
        }
    }
}
