package org.apache.tez.runtime.common.resources;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.common.ReflectionUtils;
import org.apache.tez.dag.api.EntityDescriptor;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.MemoryUpdateCallback;
import org.apache.tez.runtime.api.OutputContext;
import org.apache.tez.runtime.api.ProcessorContext;
import org.apache.tez.runtime.api.TaskContext;
import org.apache.tez.runtime.common.resources.InitialMemoryRequestContext;
import org.apache.tez.runtime.internals.api.events.SystemEventProtos;

/**
 * Collects the initial memory requests made by the components of a task
 * (inputs, outputs and the processor) and, once every expected input and
 * output has registered, hands each requestor its allocation through its
 * {@link MemoryUpdateCallback}.
 *
 * <p>When scaling is enabled (config key {@code tez.task.scale.memory.enabled},
 * read with a default of {@code true}) the split is delegated to an
 * {@link InitialMemoryAllocator} whose class name is read from
 * {@code tez.task.scale.memory.allocator.class}. When scaling is disabled every
 * requestor is assigned exactly what it asked for.
 */
@InterfaceAudience.Private
public class MemoryDistributor {
    private static final Log LOG = LogFactory.getLog(MemoryDistributor.class);

    /** Number of inputs that are expected to request memory. */
    private final int numTotalInputs;
    /** Number of outputs that are expected to request memory. */
    private final int numTotalOutputs;
    private final Configuration conf;
    /** Whether requests are scaled by an allocator or granted verbatim. */
    private final boolean isEnabled;
    private final AtomicInteger numInputsSeen = new AtomicInteger(0);
    private final AtomicInteger numOutputsSeen = new AtomicInteger(0);
    /** Contexts that have already made a request; used to reject a second request. */
    private final Set<TaskContext> dupSet =
            Collections.newSetFromMap(new ConcurrentHashMap<TaskContext, Boolean>());
    /** Memory to distribute; starts as the JVM max heap, overridable through {@link #setJvmMemory(long)}. */
    private long totalJvmMemory = Runtime.getRuntime().maxMemory();
    /** Registered requests, in registration order. */
    private final List<RequestorInfo> requestList =
            Collections.synchronizedList(new LinkedList<RequestorInfo>());

    /**
     * A single registered memory request: the request details paired with the
     * callback that must be told how much memory was finally assigned.
     */
    public static class RequestorInfo {
        private static final Log LOG = LogFactory.getLog(RequestorInfo.class);
        private final MemoryUpdateCallback callback;
        private final InitialMemoryRequestContext requestContext;

        /**
         * Builds the request, deriving the component type and the vertex name
         * from the concrete kind of {@code taskContext}.
         *
         * @param taskContext context of the requesting component; must be an
         *        {@link InputContext}, {@link OutputContext} or {@link ProcessorContext}
         * @param j requested size
         * @param memoryUpdateCallback callback to inform once memory has been assigned
         * @param entityDescriptor descriptor whose class name is recorded with the request
         * @throws IllegalArgumentException if {@code taskContext} is none of the supported kinds
         */
        public RequestorInfo(TaskContext taskContext, long j, MemoryUpdateCallback memoryUpdateCallback, EntityDescriptor<?> entityDescriptor) {
            final InitialMemoryRequestContext.ComponentType componentType;
            final String componentVertexName;
            if (taskContext instanceof InputContext) {
                componentType = InitialMemoryRequestContext.ComponentType.INPUT;
                componentVertexName = ((InputContext) taskContext).getSourceVertexName();
            } else if (taskContext instanceof OutputContext) {
                componentType = InitialMemoryRequestContext.ComponentType.OUTPUT;
                componentVertexName = ((OutputContext) taskContext).getDestinationVertexName();
            } else if (taskContext instanceof ProcessorContext) {
                componentType = InitialMemoryRequestContext.ComponentType.PROCESSOR;
                componentVertexName = ((ProcessorContext) taskContext).getTaskVertexName();
            } else {
                throw new IllegalArgumentException("Unknown type of entityContext: " + taskContext.getClass().getName());
            }
            this.requestContext = new InitialMemoryRequestContext(j, entityDescriptor.getClassName(), componentType, componentVertexName);
            this.callback = memoryUpdateCallback;
            LOG.info("Received request: " + j + ", type: " + componentType + ", componentVertexName: " + componentVertexName);
        }

        /** @return the callback to notify with the assigned amount */
        public MemoryUpdateCallback getCallback() {
            return this.callback;
        }

        /** @return the details of the request as it was registered */
        public InitialMemoryRequestContext getRequestContext() {
            return this.requestContext;
        }
    }

    /**
     * @param i total number of inputs expected to request memory
     * @param i2 total number of outputs expected to request memory
     * @param configuration configuration from which the enabled flag and, later,
     *        the allocator class name are read
     */
    public MemoryDistributor(int i, int i2, Configuration configuration) {
        this.conf = configuration;
        this.isEnabled = configuration.getBoolean("tez.task.scale.memory.enabled", true);
        this.numTotalInputs = i;
        this.numTotalOutputs = i2;
        LOG.info("InitialMemoryDistributor (isEnabled=" + this.isEnabled + ") invoked with: numInputs=" + i + ", numOutputs=" + i2 + ", JVM.maxFree=" + this.totalJvmMemory);
    }

    /**
     * Registers a memory request. Nothing is assigned until
     * {@link #makeInitialAllocations()} is called.
     *
     * @param j requested size; must be {@code >= 0}
     * @param memoryUpdateCallback callback to inform of the assigned amount; must not be null
     * @param taskContext context of the requesting component; must not be null and
     *        may only be used for one request
     * @param entityDescriptor descriptor of the requesting component; must not be null
     * @throws TezUncheckedException if {@code taskContext} has already made a request
     * @throws IllegalStateException if more inputs or outputs register than were declared
     */
    public void requestMemory(long j, MemoryUpdateCallback memoryUpdateCallback, TaskContext taskContext, EntityDescriptor<?> entityDescriptor) {
        registerRequest(j, memoryUpdateCallback, taskContext, entityDescriptor);
    }

    /**
     * Computes an allocation for every registered request and reports it to the
     * requestor's callback, in registration order.
     *
     * @throws IllegalStateException if not all declared inputs and outputs have
     *         registered, or if the allocator's result fails validation
     */
    public void makeInitialAllocations() {
        Preconditions.checkState(this.numInputsSeen.get() == this.numTotalInputs, "All inputs are expected to ask for memory");
        Preconditions.checkState(this.numOutputsSeen.get() == this.numTotalOutputs, "All outputs are expected to ask for memory");

        // A synchronizedList must be locked by the caller while it is iterated.
        // Taking a snapshot under the lock keeps validation and notification
        // consistent with each other and lets the callbacks run outside the lock.
        final List<RequestorInfo> requestors;
        synchronized (this.requestList) {
            requestors = new LinkedList<RequestorInfo>(this.requestList);
        }

        final Iterable<Long> allocations;
        if (this.isEnabled) {
            Iterable<InitialMemoryRequestContext> requestContexts = Iterables.transform(requestors, new Function<RequestorInfo, InitialMemoryRequestContext>() {
                public InitialMemoryRequestContext apply(RequestorInfo requestorInfo) {
                    return requestorInfo.getRequestContext();
                }
            });
            String allocatorClassName = this.conf.get("tez.task.scale.memory.allocator.class", "org.apache.tez.runtime.library.resources.WeightedScalingMemoryDistributor");
            LOG.info("Using Allocator class: " + allocatorClassName);
            InitialMemoryAllocator allocator = (InitialMemoryAllocator) ReflectionUtils.createClazzInstance(allocatorClassName);
            allocator.setConf(this.conf);
            allocations = allocator.assignMemory(this.totalJvmMemory, this.numTotalInputs, this.numTotalOutputs, Iterables.unmodifiableIterable(requestContexts));
            validateAllocations(allocations, requestors.size());
        } else {
            // Scaling disabled: every requestor gets exactly what it asked for.
            allocations = Iterables.transform(requestors, new Function<RequestorInfo, Long>() {
                public Long apply(RequestorInfo requestorInfo) {
                    return Long.valueOf(requestorInfo.getRequestContext().getRequestedSize());
                }
            });
        }

        // Allocations are positional: the i-th allocation belongs to the i-th requestor.
        Iterator<Long> allocationIterator = allocations.iterator();
        for (RequestorInfo requestor : requestors) {
            long allocated = allocationIterator.next().longValue();
            InitialMemoryRequestContext context = requestor.getRequestContext();
            LOG.info("Informing: " + context.getComponentType() + ", " + context.getComponentVertexName() + ", " + context.getComponentClassName() + ": requested=" + context.getRequestedSize() + ", allocated=" + allocated);
            requestor.getCallback().memoryAssigned(allocated);
        }
    }

    /**
     * Overrides the amount of memory treated as available for distribution.
     *
     * @param j the total memory to distribute
     */
    @InterfaceAudience.Private
    @VisibleForTesting
    void setJvmMemory(long j) {
        this.totalJvmMemory = j;
    }

    /**
     * Validates and records a request, and counts it against the declared
     * number of inputs or outputs. Processor requests are recorded but not counted.
     */
    private void registerRequest(long j, MemoryUpdateCallback memoryUpdateCallback, TaskContext taskContext, EntityDescriptor<?> entityDescriptor) {
        Preconditions.checkArgument(j >= 0);
        Preconditions.checkNotNull(memoryUpdateCallback);
        Preconditions.checkNotNull(taskContext);
        Preconditions.checkNotNull(entityDescriptor);
        if (!this.dupSet.add(taskContext)) {
            throw new TezUncheckedException("A single entity can only make one call to request resources for now");
        }
        RequestorInfo requestorInfo = new RequestorInfo(taskContext, j, memoryUpdateCallback, entityDescriptor);
        InitialMemoryRequestContext.ComponentType componentType = requestorInfo.getRequestContext().getComponentType();
        // The value returned by incrementAndGet() is checked directly so that the
        // comparison cannot observe an increment made by another thread.
        if (componentType == InitialMemoryRequestContext.ComponentType.INPUT) {
            int inputsSeen = this.numInputsSeen.incrementAndGet();
            Preconditions.checkState(inputsSeen <= this.numTotalInputs, "Num Requesting Inputs higher than total # of inputs: " + inputsSeen + ", " + this.numTotalInputs);
        } else if (componentType == InitialMemoryRequestContext.ComponentType.OUTPUT) {
            int outputsSeen = this.numOutputsSeen.incrementAndGet();
            Preconditions.checkState(outputsSeen <= this.numTotalOutputs, "Num Requesting Outputs higher than total # of outputs: " + outputsSeen + ", " + this.numTotalOutputs);
        }
        this.requestList.add(requestorInfo);
    }

    /**
     * Checks that an allocator returned exactly one allocation per requestor and
     * that the allocations together do not exceed the memory being distributed.
     *
     * @param iterable allocations returned by the allocator; must not be null
     * @param i number of requestors the allocations are for
     * @throws IllegalStateException if the count differs or the total is too large
     */
    private void validateAllocations(Iterable<Long> iterable, int i) {
        Preconditions.checkNotNull(iterable);
        long totalAllocated = 0;
        int numAllocations = 0;
        for (Long allocation : iterable) {
            totalAllocated += allocation.longValue();
            numAllocations++;
        }
        Preconditions.checkState(numAllocations == i, "Number of allocations must match number of requestors. Allocated=" + numAllocations + ", Requests: " + i);
        Preconditions.checkState(totalAllocated <= this.totalJvmMemory, "Total allocation should be <= availableMem. TotalAllocated: " + totalAllocated + ", totalJvmMemory: " + this.totalJvmMemory);
    }
}
