package org.apache.tez.runtime.common.resources;

import com.google.common.base.Joiner;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.runtime.api.LogicalInput;
import org.apache.tez.runtime.api.LogicalOutput;
import org.apache.tez.runtime.api.MemoryUpdateCallback;
import org.apache.tez.runtime.library.input.OrderedGroupedKVInput;
import org.apache.tez.runtime.library.input.UnorderedKVInput;
import org.apache.tez.runtime.library.output.OrderedPartitionedKVOutput;
import org.apache.tez.runtime.library.resources.WeightedScalingMemoryDistributor;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/tez/runtime/common/resources/TestWeightedScalingMemoryDistributor.class */
/**
 * Tests for {@link MemoryDistributor} when it is configured to use
 * {@link WeightedScalingMemoryDistributor} as its allocator.
 *
 * <p>Every test issues the same four memory requests (two inputs, two outputs) and then checks
 * how much memory each requester was handed back once initial allocations are made.
 */
public class TestWeightedScalingMemoryDistributor extends TestMemoryDistributor {

    // Configuration keys exercised by these tests.
    private static final String SCALE_MEMORY_ENABLED = "tez.task.scale.memory.enabled";
    private static final String ALLOCATOR_CLASS = "tez.task.scale.memory.allocator.class";
    private static final String RESERVE_FRACTION = "tez.task.scale.memory.reserve-fraction";
    private static final String ADDITIONAL_RESERVATION_PER_IO =
            "tez.task.scale.memory.additional-reservation.fraction.per-io";
    private static final String ADDITIONAL_RESERVATION_MAX =
            "tez.task.scale.memory.additional-reservation.fraction.max";
    private static final String WEIGHTED_RATIOS = "tez.task.scale.memory.ratios";

    /** JVM memory reported to the distributor in every test. */
    private static final long JVM_MEMORY = 10000L;

    /** Amount of memory each of the four requesters asks for. */
    private static final long REQUEST_SIZE = 10000L;

    /**
     * Callback that remembers the last value passed to {@link #memoryAssigned(long)}.
     * It starts at {@code -1000}, a value none of the assertions in this class expect.
     */
    private static class MemoryUpdateCallbackForTest extends MemoryUpdateCallback {
        long assigned = -1000L;

        public void memoryAssigned(long assignedSize) {
            this.assigned = assignedSize;
        }
    }

    /**
     * Prepares the shared configuration: memory scaling on, the weighted allocator selected,
     * a 0.3 reserve fraction and no additional per-IO reservation.
     *
     * <p>NOTE(review): this method carries no JUnit annotation here; presumably it overrides a
     * fixture method of {@code TestMemoryDistributor} so the runner invokes it before each test
     * - confirm against the base class.
     */
    public void setup() {
        this.conf.setBoolean(SCALE_MEMORY_ENABLED, true);
        this.conf.set(ALLOCATOR_CLASS, WeightedScalingMemoryDistributor.class.getName());
        this.conf.setDouble(RESERVE_FRACTION, 0.3d);
        this.conf.setDouble(ADDITIONAL_RESERVATION_PER_IO, 0.0d);
    }

    /**
     * Checks the allocations produced with weights generated from {@code (0, 0, 1, 2, 3, 1, 1)}
     * and no additional per-IO reservation.
     *
     * <p>The expected values add up to 7000, which matches 10000 minus the 0.3 reserve fraction.
     */
    @Test(timeout = 5000)
    public void testSimpleWeightedScaling() throws TezException {
        Configuration testConf = new Configuration(this.conf);
        testConf.setStrings(WEIGHTED_RATIOS,
                WeightedScalingMemoryDistributor.generateWeightStrings(0, 0, 1, 2, 3, 1, 1));
        // Echo the generated weight strings to aid debugging of a failing run.
        System.err.println(Joiner.on(",").join(testConf.getStringCollection(WEIGHTED_RATIOS)));

        MemoryUpdateCallbackForTest[] callbacks = runFourRequestScenario(testConf);

        Assert.assertEquals(3000L, callbacks[0].assigned);
        Assert.assertEquals(1000L, callbacks[1].assigned);
        Assert.assertEquals(1000L, callbacks[2].assigned);
        Assert.assertEquals(2000L, callbacks[3].assigned);
    }

    /**
     * Checks the allocations produced with weights generated from {@code (0, 0, 2, 3, 6, 1, 1)}
     * when an additional reservation of 0.025 per IO (capped at 0.2) is configured.
     *
     * <p>The expected values add up to 6000; presumably the four IOs add 4 * 0.025 = 0.1 on top
     * of the 0.3 base reserve - confirm against the allocator implementation.
     */
    @Test(timeout = 5000)
    public void testAdditionalReserveFractionWeightedScaling() throws TezException {
        Configuration testConf = new Configuration(this.conf);
        testConf.setStrings(WEIGHTED_RATIOS,
                WeightedScalingMemoryDistributor.generateWeightStrings(0, 0, 2, 3, 6, 1, 1));
        testConf.setDouble(ADDITIONAL_RESERVATION_PER_IO, 0.025d);
        testConf.setDouble(ADDITIONAL_RESERVATION_MAX, 0.2d);

        MemoryUpdateCallbackForTest[] callbacks = runFourRequestScenario(testConf);

        Assert.assertEquals(3000L, callbacks[0].assigned);
        Assert.assertEquals(1000L, callbacks[1].assigned);
        Assert.assertEquals(500L, callbacks[2].assigned);
        Assert.assertEquals(1500L, callbacks[3].assigned);
    }

    /**
     * Builds a distributor for two inputs and two outputs, issues the four requests shared by
     * all tests in this class and triggers the initial allocations.
     *
     * @param testConf configuration handed to the {@link MemoryDistributor}
     * @return the callbacks in request order: {@code OrderedGroupedKVInput},
     *         {@code UnorderedKVInput}, the base class's default output descriptor,
     *         {@code OrderedPartitionedKVOutput}
     * @throws TezException if the distributor reports a failure
     */
    private MemoryUpdateCallbackForTest[] runFourRequestScenario(Configuration testConf)
            throws TezException {
        MemoryDistributor distributor = new MemoryDistributor(2, 2, testConf);
        distributor.setJvmMemory(JVM_MEMORY);

        MemoryUpdateCallbackForTest orderedInput = new MemoryUpdateCallbackForTest();
        distributor.requestMemory(REQUEST_SIZE, orderedInput, createTestInputContext(),
                createTestInputDescriptor(OrderedGroupedKVInput.class));

        MemoryUpdateCallbackForTest unorderedInput = new MemoryUpdateCallbackForTest();
        distributor.requestMemory(REQUEST_SIZE, unorderedInput, createTestInputContext(),
                createTestInputDescriptor(UnorderedKVInput.class));

        // Uses the no-argument descriptor factory inherited from the base test class.
        MemoryUpdateCallbackForTest defaultOutput = new MemoryUpdateCallbackForTest();
        distributor.requestMemory(REQUEST_SIZE, defaultOutput, createTestOutputContext(),
                createTestOutputDescriptor());

        MemoryUpdateCallbackForTest orderedOutput = new MemoryUpdateCallbackForTest();
        distributor.requestMemory(REQUEST_SIZE, orderedOutput, createTestOutputContext(),
                createTestOutputDescriptor(OrderedPartitionedKVOutput.class));

        distributor.makeInitialAllocations();

        return new MemoryUpdateCallbackForTest[] {
            orderedInput, unorderedInput, defaultOutput, orderedOutput
        };
    }

    /**
     * Creates a mocked {@link InputDescriptor} whose {@code getClassName()} returns the name of
     * the given input class.
     *
     * @param cls input implementation whose class name the mock should report
     * @return the stubbed descriptor mock
     */
    private InputDescriptor createTestInputDescriptor(Class<? extends LogicalInput> cls) {
        InputDescriptor descriptor = Mockito.mock(InputDescriptor.class);
        Mockito.doReturn(cls.getName()).when(descriptor).getClassName();
        return descriptor;
    }

    /**
     * Creates a mocked {@link OutputDescriptor} whose {@code getClassName()} returns the name of
     * the given output class.
     *
     * @param cls output implementation whose class name the mock should report
     * @return the stubbed descriptor mock
     */
    private OutputDescriptor createTestOutputDescriptor(Class<? extends LogicalOutput> cls) {
        OutputDescriptor descriptor = Mockito.mock(OutputDescriptor.class);
        Mockito.doReturn(cls.getName()).when(descriptor).getClassName();
        return descriptor;
    }
}
