package org.apache.tez.runtime.library.common;

import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.LocalDirAllocator;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BoundedByteArrayOutputStream;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.DataInputBuffer;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.io.serializer.SerializationFactory;
import org.apache.hadoop.io.serializer.Serializer;
import org.apache.hadoop.util.Progress;
import org.apache.hadoop.util.Progressable;
import org.apache.tez.common.counters.GenericCounter;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.library.common.combine.Combiner;
import org.apache.tez.runtime.library.common.comparator.TezBytesComparator;
import org.apache.tez.runtime.library.common.serializer.SerializationContext;
import org.apache.tez.runtime.library.common.serializer.TezBytesWritableSerialization;
import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.ExceptionReporter;
import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.InMemoryReader;
import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.InMemoryWriter;
import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.MergeManager;
import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.TestFetcher;
import org.apache.tez.runtime.library.common.sort.impl.IFile;
import org.apache.tez.runtime.library.common.sort.impl.TezMerger;
import org.apache.tez.runtime.library.common.sort.impl.TezRawKeyValueIterator;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.mockito.Mockito;
import org.mockito.internal.util.collections.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Parameterized tests for {@link ValuesIterator}.
 *
 * <p>Each run generates random key/value data for the configured key and value classes, writes it
 * either into in-memory IFile segments or into on-disk IFiles, merges it with {@link TezMerger}
 * using {@link #comparator}, and walks the merged output through a {@link ValuesIterator}. The
 * iterator output is compared with {@link #originalData} ordered by {@link #correctComparator};
 * parameter sets whose two comparators differ expect that comparison to fail.
 */
@RunWith(Parameterized.class)
public class TestValuesIterator {
    Configuration conf;
    FileSystem fs;
    private SerializationContext serializationContext;
    /** Comparator used to sort generated data, drive the merge and group keys in the iterator. */
    final RawComparator comparator;
    /** Comparator used to build the expected ordering; same as {@link #comparator} unless overridden. */
    final RawComparator correctComparator;
    /** Whether the iterator output is expected to match the expected ordering. */
    final boolean expectedTestResult;
    int mergeFactor;
    /** Every key/value pair written by the current test. */
    final ListMultimap<Writable, Writable> originalData;
    TezRawKeyValueIterator rawKeyValueIterator;
    Path baseDir;
    Path tmpDir;
    Path[] streamPaths;

    private static final Logger LOG = LoggerFactory.getLogger(TestValuesIterator.class);
    static final String TEZ_BYTES_SERIALIZATION = TezBytesWritableSerialization.class.getName();
    static final Random rnd = new Random();

    /**
     * A {@link BytesWritable} key whose {@link #hashCode()} is supplied by the caller rather than
     * derived from its bytes. {@code equals} is inherited from {@link BytesWritable}, so the hash is
     * deliberately decoupled from equality; instances built with the no-arg constructor report 0.
     */
    public static class CustomKey extends BytesWritable {
        private static final int LENGTH_BYTES = 4;
        private int hashCode;

        /**
         * Raw comparator that skips the first {@code LENGTH_BYTES} bytes of each serialized key and
         * compares the remaining bytes with {@link WritableComparator#compareBytes}.
         * NOTE(review): assumes those leading bytes are a length header written by the serialization
         * in use — confirm for TezBytesWritableSerialization.
         */
        public static class Comparator extends WritableComparator {
            public Comparator() {
                super(CustomKey.class);
            }

            @Override
            public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
                return compareBytes(b1, s1 + LENGTH_BYTES, l1 - LENGTH_BYTES,
                        b2, s2 + LENGTH_BYTES, l2 - LENGTH_BYTES);
            }
        }

        public CustomKey() {
        }

        public CustomKey(byte[] bytes, int hashCode) {
            super(bytes);
            this.hashCode = hashCode;
        }

        @Override
        public int hashCode() {
            return this.hashCode;
        }

        static {
            WritableComparator.define(CustomKey.class, new Comparator());
        }
    }

    /** No-op {@link Progressable} handed to the merger. */
    public static class ProgressReporter implements Progressable {
        private ProgressReporter() {
        }

        @Override
        public void progress() {
        }
    }

    /** Selects which {@link RawComparator} implementation a parameter set uses. */
    enum TestWithComparator {
        LONG,
        INT,
        BYTES,
        TEZ_BYTES,
        TEXT,
        CUSTOM
    }

    /**
     * @param serializationClassName extra serialization to prepend to {@code io.serializations}, or null
     * @param keyClass key class of the generated data
     * @param valueClass value class of the generated data
     * @param comparatorType comparator used for sorting, merging and grouping
     * @param correctComparatorType comparator defining the expected order, or null to reuse comparatorType
     * @param expectedResult whether the iterator output should match the expected order
     */
    public TestValuesIterator(String serializationClassName, Class<?> keyClass, Class<?> valueClass,
            TestWithComparator comparatorType, TestWithComparator correctComparatorType, boolean expectedResult)
            throws IOException {
        this.comparator = getComparator(comparatorType);
        this.correctComparator = correctComparatorType == null ? this.comparator : getComparator(correctComparatorType);
        this.expectedTestResult = expectedResult;
        this.originalData = LinkedListMultimap.create();
        setupConf(keyClass, valueClass, serializationClassName);
    }

    /** Builds the configuration, local file system and serialization context for this run. */
    private void setupConf(Class<?> keyClass, Class<?> valueClass, String serializationClassName) throws IOException {
        mergeFactor = 2;
        conf = new Configuration();
        conf.setInt("tez.runtime.io.sort.factor", mergeFactor);
        if (serializationClassName != null) {
            // Prepend the extra serialization to the configured list.
            conf.set("io.serializations", serializationClassName + "," + conf.get("io.serializations"));
        }
        baseDir = new Path(".", getClass().getName());
        conf.setStrings("tez.runtime.framework.local.dirs", baseDir.toString());
        fs = FileSystem.getLocal(conf);
        SerializationFactory factory = new SerializationFactory(conf);
        serializationContext = new SerializationContext(keyClass, valueClass,
                factory.getSerialization(keyClass), factory.getSerialization(valueClass));
        serializationContext.applyToConf(conf);
    }

    @Before
    public void setup() throws Exception {
        fs.mkdirs(baseDir);
        tmpDir = new Path(baseDir, "tmp");
    }

    @After
    public void cleanup() throws Exception {
        fs.delete(baseDir, true);
        originalData.clear();
    }

    @Test(timeout = 20000)
    public void testIteratorWithInMemoryReader() throws IOException, InterruptedException {
        verifyIteratorData(createIterator(true));
    }

    @Test(timeout = 20000)
    public void testIteratorWithIFileReader() throws IOException, InterruptedException {
        verifyIteratorData(createIterator(false));
    }

    @Test(timeout = 20000)
    public void testCountedIteratorWithInmemoryReader() throws IOException, InterruptedException {
        verifyCountedIteratorReader(true);
    }

    @Test(timeout = 20000)
    public void testCountedIteratorWithIFileReader() throws IOException, InterruptedException {
        verifyCountedIteratorReader(false);
    }

    /**
     * Verifies iterator output and, for positive cases, that the key counter equals the number of
     * keys and the value counter equals the total number of values seen.
     */
    private void verifyCountedIteratorReader(boolean inMemory) throws IOException, InterruptedException {
        GenericCounter keyCounter = new GenericCounter("inputKeyCounter", "y3");
        GenericCounter valueCounter = new GenericCounter("inputValuesCounter", "y4");
        List<Integer> valuesPerKey = verifyIteratorData(createCountedIterator(inMemory, keyCounter, valueCounter));
        if (expectedTestResult) {
            Assert.assertEquals(valuesPerKey.size(), keyCounter.getValue());
            long totalValues = 0;
            for (int count : valuesPerKey) {
                totalValues += count;
            }
            Assert.assertEquals(totalValues, valueCounter.getValue());
        }
    }

    @Test(timeout = 20000)
    public void testIteratorWithIFileReaderEmptyPartitions() throws IOException, InterruptedException {
        Assert.assertFalse(createEmptyIterator(false).moveToNext());
        Assert.assertFalse(createEmptyIterator(true).moveToNext());
    }

    /** Asserts that advancing an exhausted iterator fails with the expected IOException message. */
    private void getNextFromFinishedIterator(ValuesIterator valuesIterator) {
        try {
            valuesIterator.moveToNext();
            Assert.fail();
        } catch (IOException e) {
            Assert.assertTrue(e.getMessage().contains("Please check if you are invoking moveToNext()"));
        }
    }

    /** Merges the given in-memory segments into {@link #rawKeyValueIterator}'s source iterator. */
    private TezRawKeyValueIterator mergeInMemory(List<TezMerger.Segment> segments)
            throws IOException, InterruptedException {
        return TezMerger.merge(conf, fs, serializationContext, segments, mergeFactor, tmpDir, comparator,
                new ProgressReporter(), new GenericCounter("readsCounter", "y"),
                new GenericCounter("writesCounter", "y1"), new GenericCounter("bytesReadCounter", "y2"),
                new Progress());
    }

    /** Merges the given IFiles; {@link #mergeFactor} is read at call time. */
    private TezRawKeyValueIterator mergeFiles(Path[] paths) throws IOException, InterruptedException {
        return TezMerger.merge(conf, fs, serializationContext, (CompressionCodec) null, false, -1, 1024, paths,
                false, mergeFactor, tmpDir, comparator, new ProgressReporter(), (TezCounter) null,
                (TezCounter) null, (TezCounter) null, (Progress) null);
    }

    /** Wraps {@link #rawKeyValueIterator} in a {@link ValuesIterator} using the given counters. */
    private ValuesIterator newValuesIterator(TezCounter keyCounter, TezCounter valueCounter) throws IOException {
        return new ValuesIterator(rawKeyValueIterator, comparator, serializationContext.getKeyClass(),
                serializationContext.getValueClass(), conf, keyCounter, valueCounter);
    }

    /** Builds a ValuesIterator over no data (no segments, or zero files). */
    private ValuesIterator createEmptyIterator(boolean inMemory) throws IOException, InterruptedException {
        if (inMemory) {
            rawKeyValueIterator = mergeInMemory(Lists.<TezMerger.Segment>newLinkedList());
        } else {
            streamPaths = new Path[0];
            rawKeyValueIterator = mergeFiles(streamPaths);
        }
        return newValuesIterator(new GenericCounter("inputKeyCounter", "y3"),
                new GenericCounter("inputValueCounter", "y4"));
    }

    /**
     * Walks the iterator against the expected data (original data ordered by
     * {@link #correctComparator}) and asserts the outcome matches {@link #expectedTestResult}.
     *
     * @return number of values verified for each key, in iteration order
     */
    private List<Integer> verifyIteratorData(ValuesIterator valuesIterator) throws IOException {
        boolean matched = true;
        List<Integer> valuesPerKey = new ArrayList<Integer>();

        ListMultimap<Writable, Writable> expected = new ImmutableListMultimap.Builder<Writable, Writable>()
                .orderKeysBy(correctComparator).putAll(originalData).build();

        // One iterator step per distinct key; all of that key's values are checked in that step.
        for (Writable expectedKey : expected.keySet()) {
            Assert.assertTrue(valuesIterator.moveToNext());
            if (!expectedKey.equals(valuesIterator.getKey())) {
                matched = false;
                break;
            }

            int valueCount = 0;
            Iterator<?> actualValues = valuesIterator.getValues().iterator();
            for (Writable expectedValue : expected.get(expectedKey)) {
                Assert.assertTrue(actualValues.hasNext());
                if (!expectedValue.equals(actualValues.next())) {
                    matched = false;
                    break;
                }
                valueCount++;
            }
            if (!matched) {
                // Stop at the first mismatch; later counts would be meaningless.
                break;
            }
            valuesPerKey.add(valueCount);
            Assert.assertTrue("At least 1 value per key", valueCount > 0);
        }

        if (expectedTestResult) {
            Assert.assertTrue(matched);
            Assert.assertFalse(valuesIterator.moveToNext());
            getNextFromFinishedIterator(valuesIterator);
        } else {
            // Drain the rest so the exhausted-iterator check is exercised in the negative case too.
            while (valuesIterator.moveToNext()) {
                // skip remaining keys
            }
            getNextFromFinishedIterator(valuesIterator);
            Assert.assertFalse(matched);
        }
        return valuesPerKey;
    }

    /** Builds a ValuesIterator over fresh random data with throwaway counters. */
    private ValuesIterator createIterator(boolean inMemory) throws IOException, InterruptedException {
        return createCountedIterator(inMemory, new GenericCounter("inputKeyCounter", "y3"),
                new GenericCounter("inputValueCounter", "y4"));
    }

    /**
     * Generates random data (in-memory segments or IFiles), merges it and wraps the result in a
     * ValuesIterator that updates the given counters.
     */
    private ValuesIterator createCountedIterator(boolean inMemory, TezCounter keyCounter, TezCounter valueCounter)
            throws IOException, InterruptedException {
        if (inMemory) {
            rawKeyValueIterator = mergeInMemory(createInMemStreams());
        } else {
            // createFiles() may raise mergeFactor, so it must run before mergeFiles().
            streamPaths = createFiles();
            rawKeyValueIterator = mergeFiles(streamPaths);
        }
        return newValuesIterator(keyCounter, valueCounter);
    }

    /**
     * Parameter sets: {extra serialization, key class, value class, comparator,
     * expected-order comparator (null = same), expected result}.
     */
    @Parameterized.Parameters(name = "test[{0}, {1}, {2}, {3} {4} {5}]")
    public static Collection<Object[]> getParameters() {
        List<Object[]> params = new ArrayList<Object[]>();
        params.add(new Object[]{null, Text.class, Text.class, TestWithComparator.TEXT, null, true});
        params.add(new Object[]{null, LongWritable.class, Text.class, TestWithComparator.LONG, null, true});
        params.add(new Object[]{null, IntWritable.class, Text.class, TestWithComparator.INT, null, true});
        params.add(new Object[]{null, BytesWritable.class, BytesWritable.class, TestWithComparator.BYTES, null, true});
        params.add(new Object[]{TEZ_BYTES_SERIALIZATION, BytesWritable.class, BytesWritable.class,
                TestWithComparator.TEZ_BYTES, null, true});
        params.add(new Object[]{TEZ_BYTES_SERIALIZATION, BytesWritable.class, LongWritable.class,
                TestWithComparator.TEZ_BYTES, null, true});
        params.add(new Object[]{TEZ_BYTES_SERIALIZATION, CustomKey.class, LongWritable.class,
                TestWithComparator.TEZ_BYTES, null, true});
        // Negative cases: data is sorted with one comparator but verified against another.
        params.add(new Object[]{TEZ_BYTES_SERIALIZATION, BytesWritable.class, BytesWritable.class,
                TestWithComparator.BYTES, TestWithComparator.TEZ_BYTES, false});
        params.add(new Object[]{TEZ_BYTES_SERIALIZATION, CustomKey.class, LongWritable.class,
                TestWithComparator.CUSTOM, TestWithComparator.TEZ_BYTES, false});
        return params;
    }

    /** Maps a {@link TestWithComparator} to a new comparator instance. */
    private RawComparator getComparator(TestWithComparator type) {
        switch (type) {
            case LONG:
                return new LongWritable.Comparator();
            case INT:
                return new IntWritable.Comparator();
            case BYTES:
                return new BytesWritable.Comparator();
            case TEZ_BYTES:
                return new TezBytesComparator();
            case TEXT:
                return new Text.Comparator();
            case CUSTOM:
                return new CustomKey.Comparator();
            default:
                throw new IllegalArgumentException("Unsupported comparator type: " + type);
        }
    }

    /**
     * Writes between 2 and 9 IFiles of random sorted data under {@link #baseDir}, recording every
     * record in {@link #originalData}.
     *
     * @return paths of the written files
     */
    private Path[] createFiles() throws IOException {
        int numStreams = Math.max(2, rnd.nextInt(10));
        // Raise the merge factor to at least the number of files written.
        mergeFactor = Math.max(mergeFactor, numStreams);
        LOG.info("No of streams : {}", numStreams);
        Path[] paths = new Path[numStreams];
        for (int i = 0; i < numStreams; i++) {
            paths[i] = new Path(baseDir, "ifile_" + i + ".out");
            try (FSDataOutputStream out = fs.create(paths[i])) {
                IFile.Writer writer = new IFile.Writer(serializationContext.getKeySerialization(),
                        serializationContext.getValSerialization(), out, serializationContext.getKeyClass(),
                        serializationContext.getValueClass(), (CompressionCodec) null, (TezCounter) null,
                        (TezCounter) null, true);
                Map<Writable, Writable> data = createData();
                for (Map.Entry<Writable, Writable> entry : data.entrySet()) {
                    Writable key = entry.getKey();
                    Writable value = entry.getValue();
                    writer.append(key, value);
                    originalData.put(key, value);
                    // Roughly half of the keys get extra identical records so keys carry multiple values.
                    if (rnd.nextInt() % 2 == 0) {
                        int extraCopies = rnd.nextInt(100);
                        for (int copy = 0; copy < extraCopies; copy++) {
                            writer.append(key, value);
                            originalData.put(key, value);
                        }
                    }
                }
                LOG.info("Wrote {} in {}", data.size(), paths[i]);
                writer.close();
            }
        }
        return paths;
    }

    /**
     * Builds between 2 and 9 in-memory IFile segments of random sorted data, recording every record
     * in {@link #originalData}.
     */
    public List<TezMerger.Segment> createInMemStreams() throws IOException {
        int numStreams = Math.max(2, rnd.nextInt(10));
        LOG.info("No of streams : {}", numStreams);

        Serializer keySerializer = serializationContext.getKeySerializer();
        Serializer valueSerializer = serializationContext.getValueSerializer();
        MergeManager mergeManager = new MergeManager(conf, fs,
                new LocalDirAllocator("tez.runtime.framework.local.dirs"), createTezInputContext(), (Combiner) null,
                (TezCounter) null, (TezCounter) null, (TezCounter) null, (ExceptionReporter) null,
                10485760L /* 10 MB */, (CompressionCodec) null, false, -1);

        // Records are serialized into the output buffers, then exposed to the writer via the input buffers.
        DataOutputBuffer keyOut = new DataOutputBuffer();
        DataOutputBuffer valueOut = new DataOutputBuffer();
        DataInputBuffer keyIn = new DataInputBuffer();
        DataInputBuffer valueIn = new DataInputBuffer();
        keySerializer.open(keyOut);
        valueSerializer.open(valueOut);

        List<TezMerger.Segment> segments = new LinkedList<TezMerger.Segment>();
        for (int i = 0; i < numStreams; i++) {
            BoundedByteArrayOutputStream bout = new BoundedByteArrayOutputStream(1048576 /* 1 MB */);
            InMemoryWriter writer = new InMemoryWriter(bout);
            for (Map.Entry<Writable, Writable> entry : createData().entrySet()) {
                keySerializer.serialize(entry.getKey());
                valueSerializer.serialize(entry.getValue());
                keyIn.reset(keyOut.getData(), 0, keyOut.getLength());
                valueIn.reset(valueOut.getData(), 0, valueOut.getLength());
                writer.append(keyIn, valueIn);
                originalData.put(entry.getKey(), entry.getValue());
                keyOut.reset();
                valueOut.reset();
                keyIn.reset();
                valueIn.reset();
            }
            // The reader is given the whole backing buffer; the writer is closed afterwards.
            segments.add(new TezMerger.Segment(new InMemoryReader(mergeManager, (InputAttemptIdentifier) null,
                    bout.getBuffer(), 0, bout.getBuffer().length), (TezCounter) null));
            writer.close();
        }
        return segments;
    }

    /** Returns a mocked {@link InputContext} with the few values the MergeManager setup reads. */
    private InputContext createTezInputContext() {
        TezCounters counters = new TezCounters();
        InputContext inputContext = Mockito.mock(InputContext.class);
        Mockito.doReturn(104857600L).when(inputContext).getTotalMemoryAvailableToTask();
        Mockito.doReturn(counters).when(inputContext).getCounters();
        Mockito.doReturn(1).when(inputContext).getInputIndex();
        Mockito.doReturn("srcVertex").when(inputContext).getSourceVertexName();
        Mockito.doReturn(1).when(inputContext).getTaskVertexIndex();
        Mockito.doReturn(UserPayload.create(ByteBuffer.wrap(new byte[1024]))).when(inputContext).getUserPayload();
        return inputContext;
    }

    /**
     * Generates at least 10 random key/value pairs sorted by {@link #comparator}; keys that compare
     * equal collapse into one entry.
     */
    private Map<Writable, Writable> createData() {
        Map<Writable, Writable> data = new TreeMap<Writable, Writable>((Comparator) comparator);
        // Draw the entry count once, rather than re-drawing the bound on every iteration.
        int numEntries = Math.max(10, rnd.nextInt(50));
        for (int i = 0; i < numEntries; i++) {
            data.put(createData(serializationContext.getKeyClass()), createData(serializationContext.getValueClass()));
        }
        return data;
    }

    /**
     * Creates one random Writable of the given class.
     *
     * @throws IllegalArgumentException if the class is not one of the supported Writables
     */
    private Writable createData(Class<?> cls) {
        String name = cls.getName();
        if (name.equalsIgnoreCase(BytesWritable.class.getName())) {
            return new BytesWritable(new BigInteger(256, rnd).toString().getBytes());
        }
        if (name.equalsIgnoreCase(IntWritable.class.getName())) {
            return new IntWritable(rnd.nextInt());
        }
        if (name.equalsIgnoreCase(LongWritable.class.getName())) {
            return new LongWritable(rnd.nextLong());
        }
        if (name.equalsIgnoreCase(CustomKey.class.getName())) {
            String text = new BigInteger(256, rnd).toString() + "_" + new BigInteger(256, rnd).toString();
            return new CustomKey(text.getBytes(), text.hashCode());
        }
        if (name.equalsIgnoreCase(Text.class.getName())) {
            return new Text(new BigInteger(256, rnd).toString() + "_" + new BigInteger(256, rnd).toString());
        }
        throw new IllegalArgumentException("Illegal argument : " + cls.getName());
    }
}
