package org.apache.tez.runtime.library.resources;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.runtime.common.resources.InitialMemoryAllocator;
import org.apache.tez.runtime.common.resources.InitialMemoryRequestContext;
import org.apache.tez.runtime.library.input.OrderedGroupedInputLegacy;
import org.apache.tez.runtime.library.input.OrderedGroupedKVInput;
import org.apache.tez.runtime.library.input.UnorderedKVInput;
import org.apache.tez.runtime.library.output.OrderedPartitionedKVOutput;
import org.apache.tez.runtime.library.output.UnorderedPartitionedKVOutput;

@InterfaceAudience.Public
@InterfaceStability.Unstable
/* loaded from: input_file:org/apache/tez/runtime/library/resources/WeightedScalingMemoryDistributor.class */
public class WeightedScalingMemoryDistributor implements InitialMemoryAllocator {
    /** Default used by computeReservedFraction for "tez.task.scale.memory.additional-reservation.fraction.max". */
    static final double MAX_ADDITIONAL_RESERVATION_FRACTION_PER_IO = 0.1d;
    /** Default used by computeReservedFraction for "tez.task.scale.memory.additional-reservation.fraction.per-io". */
    static final double RESERVATION_FRACTION_PER_IO = 0.015d;
    /** Configuration supplied through setConf; read when parsing weights and reservation fractions. */
    private Configuration conf;
    /** Weight configured for each RequestType; filled in by populateTypeScaleMap. */
    private EnumMap<RequestType, Integer> typeScaleMap = Maps.newEnumMap(RequestType.class);
    /** Number of requests recorded so far by initialProcessMemoryRequestContext. */
    private int numRequests = 0;
    /** Sum of the weights of all recorded requests. */
    private int numRequestsScaled = 0;
    /** Sum of the requested sizes of all recorded requests. */
    private long totalRequested = 0;
    /** Recorded requests, in the order they were processed. */
    private List<Request> requests = Lists.newArrayList();
    private static final Log LOG = LogFactory.getLog(WeightedScalingMemoryDistributor.class);
    /** Default value for the "tez.task.scale.memory.ratios" setting read in populateTypeScaleMap. */
    static final String[] DEFAULT_TASK_MEMORY_WEIGHTED_RATIOS = generateWeightStrings(1, 1, 12, 12, 1, 1);

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/tez/runtime/library/resources/WeightedScalingMemoryDistributor$Request.class */
    /**
     * Mutable record of a single memory request: who asked, how much was asked for, and the
     * type/weight the distributor assigned to it.
     */
    public static class Request {
        /** Class name of the component that issued the request. */
        String componentClassname;
        /** Requested size, as reported by the request context. */
        long requestSize;
        /** Category the component class name was mapped to. */
        private RequestType requestType;
        /** Weight applied when scaling; may be overwritten by the enclosing class. */
        private int requestWeight;

        /**
         * @param className class name of the requesting component
         * @param size requested size
         * @param requestType category assigned to the request
         * @param weight scale weight for the request
         */
        Request(String className, long size, RequestType requestType, int weight) {
            this.componentClassname = className;
            this.requestSize = size;
            this.requestType = requestType;
            this.requestWeight = weight;
        }
    }

    @InterfaceAudience.Private
    @VisibleForTesting
    /* loaded from: input_file:org/apache/tez/runtime/library/resources/WeightedScalingMemoryDistributor$RequestType.class */
    /**
     * Categories used to look up a scaling weight for a memory request. The constant names are
     * also the keys parsed from the "NAME:weight" ratio strings in populateTypeScaleMap.
     */
    public enum RequestType {
        /** Assigned by getRequestTypeForClass to UnorderedPartitionedKVOutput. */
        PARTITIONED_UNSORTED_OUTPUT,
        /** Assigned by getRequestTypeForClass to UnorderedKVInput. */
        UNSORTED_INPUT,
        /** Not assigned by getRequestTypeForClass in this file; generateWeightStrings gives it weight 0. */
        UNSORTED_OUTPUT,
        /** Assigned by getRequestTypeForClass to OrderedPartitionedKVOutput. */
        SORTED_OUTPUT,
        /** Assigned by getRequestTypeForClass to OrderedGroupedKVInput and OrderedGroupedInputLegacy. */
        SORTED_MERGED_INPUT,
        /** Not assigned by getRequestTypeForClass in this file. */
        PROCESSOR,
        /** Fallback for any class name that matches none of the known components. */
        OTHER
    }

    /**
     * Records every incoming request, weights it by its request type, and splits the memory that
     * remains after the reserved fraction among the requests in proportion to
     * {@code weight * requestedSize}. No request is ever given more than it asked for.
     *
     * <p>NOTE(review): the counters and the request list are instance fields that this method
     * never resets, so it appears to assume a single call per instance - confirm with callers.
     *
     * @param j total memory from which the reserved fraction is subtracted before distribution
     * @param i not used by this implementation
     * @param i2 not used by this implementation
     * @param iterable the memory requests to process
     * @return one allocation per recorded request, in the order the requests were recorded
     * @throws IllegalStateException if the computed reserved fraction is outside [0, 1]
     */
    public Iterable<Long> assignMemory(long j, int i, int i2, Iterable<InitialMemoryRequestContext> iterable) {
        populateTypeScaleMap();
        for (InitialMemoryRequestContext context : iterable) {
            initialProcessMemoryRequestContext(context);
        }
        if (this.numRequestsScaled == 0) {
            // Every request had weight 0: fall back to weighting all requests equally.
            this.numRequestsScaled = this.numRequests;
            for (Request request : this.requests) {
                request.requestWeight = 1;
            }
        }

        double totalScaledRequest = 0.0d;
        for (Request request : this.requests) {
            // Divide as double: int / int would truncate the factor to 0 for almost every request.
            double requestFactor = request.requestWeight / (double) this.numRequestsScaled;
            totalScaledRequest += requestFactor * request.requestSize;
        }

        double reservedFraction = computeReservedFraction(this.numRequests);
        Preconditions.checkState(reservedFraction >= 0.0d && reservedFraction <= 1.0d);
        long availableForAllocation = (long) (j - (reservedFraction * j));
        long totalJvmMem = Runtime.getRuntime().maxMemory();
        // The ratio is computed as double so the "0.00" format shows the real fraction.
        double requestedToHeapRatio = this.totalRequested / (double) totalJvmMem;
        LOG.info("Scaling Requests. NumRequests: " + this.numRequests + ", numScaledRequests: " + this.numRequestsScaled + ", TotalRequested: " + this.totalRequested + ", TotalRequestedScaled: " + totalScaledRequest + ", TotalJVMHeap: " + totalJvmMem + ", TotalAvailable: " + availableForAllocation + ", TotalRequested/TotalJVMHeap:" + new DecimalFormat("0.00").format(requestedToHeapRatio));

        List<Long> allocations = Lists.newArrayListWithCapacity(this.numRequests);
        for (Request request : this.requests) {
            if (request.requestSize == 0) {
                allocations.add(0L);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Scaling requested " + request.componentClassname + " of type " + request.requestType + " 0 to allocated: 0");
                }
            } else {
                double requestFactor = request.requestWeight / (double) this.numRequestsScaled;
                double scaledShare = (requestFactor * request.requestSize) / totalScaledRequest;
                // Cap at the requested size so a request never receives more than it asked for.
                long allocated = Math.min((long) (scaledShare * availableForAllocation), request.requestSize);
                allocations.add(Long.valueOf(allocated));
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Scaling requested " + request.componentClassname + " of type " + request.requestType + " " + request.requestSize + "  to allocated: " + allocated);
                }
            }
        }
        return allocations;
    }

    /**
     * Records one incoming request: bumps the request counter and the running total of requested
     * bytes, resolves the request's type and weight, stores it, and adds the weight to the
     * running weight total.
     *
     * @param requestContext the request to record
     */
    private void initialProcessMemoryRequestContext(InitialMemoryRequestContext requestContext) {
        this.numRequests += 1;
        this.totalRequested = this.totalRequested + requestContext.getRequestedSize();

        RequestType type = getRequestTypeForClass(requestContext.getComponentClassName());
        Integer weight = getScaleFactorForType(type);

        Request recorded = new Request(
                requestContext.getComponentClassName(),
                requestContext.getRequestedSize(),
                type,
                weight.intValue());
        this.requests.add(recorded);

        LOG.info("ScaleFactor: " + weight + ", for type: " + type);
        this.numRequestsScaled = this.numRequestsScaled + weight.intValue();
    }

    /**
     * Looks up the configured weight for a request type.
     *
     * @param type the request type to look up
     * @return the configured weight, or 0 (after logging a warning) when none is configured
     */
    private Integer getScaleFactorForType(RequestType type) {
        Integer configured = this.typeScaleMap.get(type);
        if (configured != null) {
            return configured;
        }
        LOG.warn("Bad scale factor for requestType: " + type + ", Using factor 0");
        return 0;
    }

    /**
     * Maps a component class name onto a request type by exact comparison against the known
     * input/output classes.
     *
     * @param className fully qualified class name of the requesting component
     * @return the matching request type, or {@link RequestType#OTHER} when nothing matches
     */
    private RequestType getRequestTypeForClass(String className) {
        if (className.equals(OrderedPartitionedKVOutput.class.getName())) {
            return RequestType.SORTED_OUTPUT;
        }
        if (className.equals(OrderedGroupedKVInput.class.getName())
                || className.equals(OrderedGroupedInputLegacy.class.getName())) {
            return RequestType.SORTED_MERGED_INPUT;
        }
        if (className.equals(UnorderedKVInput.class.getName())) {
            return RequestType.UNSORTED_INPUT;
        }
        if (className.equals(UnorderedPartitionedKVOutput.class.getName())) {
            return RequestType.PARTITIONED_UNSORTED_OUTPUT;
        }
        // Unknown component: log it and use the catch-all type.
        LOG.info("Falling back to RequestType.OTHER for class: " + className);
        return RequestType.OTHER;
    }

    /**
     * Reads the "tez.task.scale.memory.ratios" setting (defaulting to
     * {@link #DEFAULT_TASK_MEMORY_WEIGHTED_RATIOS}) and stores the weight of each request type in
     * {@code typeScaleMap}. Each entry must have the form {@code REQUEST_TYPE_NAME:weight}.
     * If the configuration yields no entries at all, every type gets weight 1.
     *
     * @throws IllegalArgumentException if the number of entries differs from the number of
     *     request types, if a type is configured more than once, or if a name is not a
     *     {@link RequestType} constant
     * @throws IllegalStateException if an entry is not of the form {@code name:weight} or a
     *     weight is negative
     * @throws NumberFormatException if a weight is not a valid integer
     */
    private void populateTypeScaleMap() {
        String[] ratios = this.conf.getStrings("tez.task.scale.memory.ratios", DEFAULT_TASK_MEMORY_WEIGHTED_RATIOS);
        RequestType[] allTypes = RequestType.values();
        if (ratios == null) {
            LOG.info("No ratio specified. Falling back to Linear scaling");
            ratios = new String[allTypes.length];
            for (int idx = 0; idx < allTypes.length; idx++) {
                ratios[idx] = allTypes[idx].name() + ":1";
            }
        } else if (ratios.length != allTypes.length) {
            throw new IllegalArgumentException("Number of entries in the configured ratios should be equal to the number of entries in RequestType: " + allTypes.length);
        }

        // Tracks the types already seen so a duplicate entry is rejected instead of silently overwriting.
        HashSet<RequestType> seenTypes = new HashSet<RequestType>();
        for (String ratio : ratios) {
            String[] parts = ratio.split(":");
            Preconditions.checkState(parts.length == 2);
            RequestType type = RequestType.valueOf(parts[0]);
            int weight = Integer.parseInt(parts[1]);
            if (!seenTypes.add(type)) {
                throw new IllegalArgumentException("Cannot configure the same RequestType: " + type + " multiple times");
            }
            Preconditions.checkState(weight >= 0, "Ratio must be >= 0");
            this.typeScaleMap.put(type, weight);
        }
    }

    /**
     * Computes the fraction of memory to hold back from distribution: a configured base fraction
     * plus a per-request increment that is capped at a configured maximum.
     *
     * @param numTotalRequests number of requests the per-request increment is multiplied by
     * @return the base fraction plus the capped additional fraction
     * @throws IllegalArgumentException if the maximum additional fraction is outside [0, 1] or
     *     the per-request fraction is outside [0, maximum]
     * @throws IllegalStateException if the resulting fraction exceeds 1
     */
    private double computeReservedFraction(int numTotalRequests) {
        final double perIoFraction = this.conf.getDouble(
                "tez.task.scale.memory.additional-reservation.fraction.per-io", RESERVATION_FRACTION_PER_IO);
        final double maxAdditionalFraction = this.conf.getDouble(
                "tez.task.scale.memory.additional-reservation.fraction.max", MAX_ADDITIONAL_RESERVATION_FRACTION_PER_IO);

        Preconditions.checkArgument(maxAdditionalFraction >= 0.0d && maxAdditionalFraction <= 1.0d);
        Preconditions.checkArgument(perIoFraction <= maxAdditionalFraction && perIoFraction >= 0.0d);
        if (LOG.isDebugEnabled()) {
            LOG.debug("ReservationFractionPerIO=" + perIoFraction + ", MaxPerIOReserveFraction=" + maxAdditionalFraction);
        }

        final double baseFraction = this.conf.getDouble("tez.task.scale.memory.reserve-fraction", 0.3d);
        // The additional reservation grows with the request count but never beyond the maximum.
        final double additionalFraction = Math.min(maxAdditionalFraction, numTotalRequests * perIoFraction);
        final double reservedFraction = baseFraction + additionalFraction;
        Preconditions.checkState(reservedFraction <= 1.0d);

        LOG.info("InitialReservationFraction=" + baseFraction + ", AdditionalReservationFractionForIOs=" + additionalFraction + ", finalReserveFractionUsed=" + reservedFraction);
        return reservedFraction;
    }

    /**
     * Builds the {@code NAME:weight} strings for every {@link RequestType}, in the format parsed
     * by populateTypeScaleMap. {@link RequestType#UNSORTED_OUTPUT} always gets weight 0.
     *
     * @param i weight for {@link RequestType#PARTITIONED_UNSORTED_OUTPUT}
     * @param i2 weight for {@link RequestType#UNSORTED_INPUT}
     * @param i3 weight for {@link RequestType#SORTED_OUTPUT}
     * @param i4 weight for {@link RequestType#SORTED_MERGED_INPUT}
     * @param i5 weight for {@link RequestType#PROCESSOR}
     * @param i6 weight for {@link RequestType#OTHER}
     * @return an array with one entry per request type
     */
    public static String[] generateWeightStrings(int i, int i2, int i3, int i4, int i5, int i6) {
        final String separator = ":";
        String[] weights = new String[RequestType.values().length];
        int slot = 0;
        weights[slot++] = RequestType.PARTITIONED_UNSORTED_OUTPUT.name() + separator + i;
        weights[slot++] = RequestType.UNSORTED_OUTPUT.name() + separator + "0";
        weights[slot++] = RequestType.UNSORTED_INPUT.name() + separator + i2;
        weights[slot++] = RequestType.SORTED_OUTPUT.name() + separator + i3;
        weights[slot++] = RequestType.SORTED_MERGED_INPUT.name() + separator + i4;
        weights[slot++] = RequestType.PROCESSOR.name() + separator + i5;
        weights[slot++] = RequestType.OTHER.name() + separator + i6;
        return weights;
    }

    /**
     * Stores the configuration later read when parsing weights and reservation fractions.
     *
     * @param configuration the configuration to use
     */
    public void setConf(Configuration configuration) {
        conf = configuration;
    }

    /**
     * @return the configuration previously supplied to {@code setConf}, or null if none was set
     */
    public Configuration getConf() {
        return conf;
    }
}
