/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.operators;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.operators.Driver;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.GroupReduceDriver;
import org.apache.flink.runtime.operators.ReduceTaskExternalITCase;
import org.apache.flink.runtime.operators.sort.ExternalSorter;
import org.apache.flink.runtime.operators.testutils.DelayingInfinitiveInputIterator;
import org.apache.flink.runtime.operators.testutils.DriverTestBase;
import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
import org.apache.flink.runtime.operators.testutils.NirvanaOutputList;
import org.apache.flink.runtime.operators.testutils.TaskCancelThread;
import org.apache.flink.runtime.operators.testutils.UniformRecordGenerator;
import org.apache.flink.runtime.testutils.recordutils.RecordComparator;
import org.apache.flink.runtime.testutils.recordutils.RecordSerializerFactory;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator;
import org.assertj.core.api.AbstractIntegerAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.AtomicBooleanAssert;
import org.assertj.core.api.ListAssert;
import org.junit.jupiter.api.TestTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tests for {@link GroupReduceDriver} running with {@link DriverStrategy#SORTED_GROUP_REDUCE}.
 *
 * <p>Covers input sorted by the test harness, pre-sorted input, input produced by an {@link
 * ExternalSorter} with a combiner, a user function that throws, and cancellation while sorting and
 * while reducing.
 */
public class ReduceTaskTest extends DriverTestBase<GroupReduceFunction<Record, Record>> {
    private static final Logger LOG = LoggerFactory.getLogger(ReduceTaskTest.class);

    /** Compares records on field 0, which the mock functions below read as the group key. */
    private final RecordComparator comparator =
            new RecordComparator(new int[] {0}, new Class[] {IntValue.class});

    /** Receives the records emitted by the driver under test. */
    private final List<Record> outList = new ArrayList<>();

    public ReduceTaskTest(ExecutionConfig config) {
        super(config, 0L, 1, 0x300000L);
    }

    @TestTemplate
    void testReduceTaskWithSortingInput() {
        final int keyCnt = 100;
        final int valCnt = 20;

        addDriverComparator(comparator);
        setOutput(outList);
        getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);

        try {
            addInputSorted(
                    new UniformRecordGenerator(keyCnt, valCnt, false), comparator.duplicate());
            GroupReduceDriver<Record, Record> testTask = new GroupReduceDriver<>();
            testDriver(testTask, MockReduceStub.class);
        } catch (Exception e) {
            LOG.info("Exception while running the test task.", e);
            Assertions.fail("Exception in Test: " + e.getMessage(), e);
        }

        // MockReduceStub writes (group size - key) into field 1.
        assertResultsAndClear(keyCnt, valCnt);
    }

    @TestTemplate
    void testReduceTaskOnPreSortedInput() {
        final int keyCnt = 100;
        final int valCnt = 20;

        addInput(new UniformRecordGenerator(keyCnt, valCnt, true));
        addDriverComparator(comparator);
        setOutput(outList);
        getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);

        GroupReduceDriver<Record, Record> testTask = new GroupReduceDriver<>();
        try {
            testDriver(testTask, MockReduceStub.class);
        } catch (Exception e) {
            LOG.info("Exception while running the test task.", e);
            Assertions.fail("Invoke method caused exception: " + e.getMessage(), e);
        }

        // MockReduceStub writes (group size - key) into field 1.
        assertResultsAndClear(keyCnt, valCnt);
    }

    @TestTemplate
    void testCombiningReduceTask() throws IOException {
        final int keyCnt = 100;
        final int valCnt = 20;

        addDriverComparator(comparator);
        setOutput(outList);
        getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);

        // The sorter must be created in the resource header: try-with-resources variables are
        // implicitly final and cannot be assigned inside the block.
        try (ExternalSorter<Record> sorter =
                ExternalSorter.newBuilder(
                                getMemoryManager(),
                                getContainingTask(),
                                RecordSerializerFactory.get().getSerializer(),
                                comparator.duplicate())
                        .maxNumFileHandles(4)
                        .withCombiner(new MockCombiningReduceStub())
                        .enableSpilling(getIOManager(), 0.8f)
                        .memoryFraction(perSortFractionMem)
                        .objectReuse(true)
                        .largeRecords(true)
                        .build(new UniformRecordGenerator(keyCnt, valCnt, false))) {
            addInput(sorter.getIterator());
            GroupReduceDriver<Record, Record> testTask = new GroupReduceDriver<>();
            testDriver(testTask, MockCombiningReduceStub.class);
        } catch (Exception e) {
            LOG.info("Exception while running the test task.", e);
            Assertions.fail("Invoke method caused exception: " + e.getMessage(), e);
        }

        // Expected per-key sum: 1 + 2 + ... + (valCnt - 1); the reducer subtracts the key.
        final int expSum = valCnt * (valCnt - 1) / 2;
        assertResultsAndClear(keyCnt, expSum);
    }

    @TestTemplate
    void testFailingReduceTask() {
        final int keyCnt = 100;
        final int valCnt = 20;

        addInput(new UniformRecordGenerator(keyCnt, valCnt, true));
        addDriverComparator(comparator);
        setOutput(outList);
        getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);

        GroupReduceDriver<Record, Record> testTask = new GroupReduceDriver<>();
        Assertions.assertThatThrownBy(() -> testDriver(testTask, MockFailingReduceStub.class))
                .isInstanceOf(ExpectedTestException.class);

        outList.clear();
    }

    @TestTemplate
    void testCancelReduceTaskWhileSorting() {
        addDriverComparator(comparator);
        setOutput(new NirvanaOutputList());
        getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);

        GroupReduceDriver<Record, Record> testTask = new GroupReduceDriver<>();
        try {
            addInputSorted(new DelayingInfinitiveInputIterator(100), comparator.duplicate());
        } catch (Exception e) {
            LOG.error("Exception while setting up the sorted input.", e);
            Assertions.fail(e.getMessage(), e);
        }

        boolean success = runAndCancel(testTask, MockReduceStub.class, 1);

        Assertions.assertThat(success)
                .withFailMessage("Test threw an exception even though it was properly canceled.")
                .isTrue();
    }

    @TestTemplate
    void testCancelReduceTaskWhileReducing() {
        final int keyCnt = 1000;
        final int valCnt = 2;

        addInput(new UniformRecordGenerator(keyCnt, valCnt, true));
        addDriverComparator(comparator);
        setOutput(new NirvanaOutputList());
        getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);

        GroupReduceDriver<Record, Record> testTask = new GroupReduceDriver<>();
        boolean success = runAndCancel(testTask, MockDelayingReduceStub.class, 2);

        // Previously the flag was set but never checked, so failures went unnoticed.
        Assertions.assertThat(success)
                .withFailMessage("Test threw an exception even though it was properly canceled.")
                .isTrue();
    }

    /**
     * Asserts the output of the driver under test, then clears {@link #outList}.
     *
     * <p>Checks that {@link #outList} holds exactly {@code expectedSize} records, and that in each
     * record field 1 equals {@code expectedBase} minus the key in field 0.
     *
     * @param expectedSize number of records expected (one per key)
     * @param expectedBase value that field 1 plus the key must add up to
     */
    private void assertResultsAndClear(int expectedSize, int expectedBase) {
        Assertions.assertThat(outList)
                .withFailMessage(
                        "Resultset size was %d. Expected was %d", outList.size(), expectedSize)
                .hasSize(expectedSize);
        for (Record record : outList) {
            int key = record.getField(0, IntValue.class).getValue();
            Assertions.assertThat(record.getField(1, IntValue.class).getValue())
                    .withFailMessage("Incorrect result")
                    .isEqualTo(expectedBase - key);
        }
        outList.clear();
    }

    /**
     * Runs {@code testDriver(testTask, stubClass)} on a separate thread and cancels it.
     *
     * <p>Cancellation is performed by a {@link TaskCancelThread} created with {@code cancelDelay}.
     * The method then waits for both threads to finish.
     *
     * @param testTask driver under test
     * @param stubClass user function class handed to {@code testDriver}
     * @param cancelDelay delay argument passed to {@link TaskCancelThread}
     * @return {@code true} if {@code testDriver} returned normally, {@code false} if it threw
     */
    private boolean runAndCancel(
            GroupReduceDriver<Record, Record> testTask,
            Class<? extends GroupReduceFunction<Record, Record>> stubClass,
            int cancelDelay) {
        final AtomicBoolean success = new AtomicBoolean(false);
        Thread taskRunner =
                new Thread(
                        () -> {
                            try {
                                testDriver(testTask, stubClass);
                                success.set(true);
                            } catch (Exception e) {
                                LOG.error("Exception while running the canceled test task.", e);
                            }
                        });
        taskRunner.start();

        TaskCancelThread tct = new TaskCancelThread(cancelDelay, taskRunner, this);
        tct.start();

        try {
            tct.join();
            taskRunner.join();
        } catch (InterruptedException ie) {
            // Restore the interrupt status before failing so callers can still observe it.
            Thread.currentThread().interrupt();
            Assertions.fail("Joining threads failed", ie);
        }
        return success.get();
    }

    /** Emits the last record of each group with field 1 set to (group size - key). */
    public static class MockReduceStub extends RichGroupReduceFunction<Record, Record> {
        private static final long serialVersionUID = 1L;
        private final IntValue key = new IntValue();
        private final IntValue value = new IntValue();

        @Override
        public void reduce(Iterable<Record> records, Collector<Record> out) {
            Record element = null;
            int cnt = 0;
            for (Record next : records) {
                element = next;
                cnt++;
            }
            element.getField(0, key);
            value.setValue(cnt - key.getValue());
            element.setField(1, value);
            out.collect(element);
        }
    }

    /**
     * Sums field 1 within each group.
     *
     * <p>{@link #combine} emits the last record of the group with the partial sum in field 1.
     * {@link #reduce} emits the last record with (sum - key) in field 1.
     */
    public static class MockCombiningReduceStub
            implements GroupReduceFunction<Record, Record>, GroupCombineFunction<Record, Record> {
        private static final long serialVersionUID = 1L;
        private final IntValue key = new IntValue();
        private final IntValue value = new IntValue();
        // combine() uses its own scratch value, separate from the one reduce() uses.
        private final IntValue combineValue = new IntValue();

        @Override
        public void reduce(Iterable<Record> records, Collector<Record> out) {
            Record element = null;
            int sum = 0;
            for (Record next : records) {
                element = next;
                element.getField(1, value);
                sum += value.getValue();
            }
            element.getField(0, key);
            value.setValue(sum - key.getValue());
            element.setField(1, value);
            out.collect(element);
        }

        @Override
        public void combine(Iterable<Record> records, Collector<Record> out) {
            Record element = null;
            int sum = 0;
            for (Record next : records) {
                element = next;
                element.getField(1, combineValue);
                sum += combineValue.getValue();
            }
            combineValue.setValue(sum);
            element.setField(1, combineValue);
            out.collect(element);
        }
    }

    /**
     * Behaves like {@link MockReduceStub} for the first nine groups and throws {@link
     * ExpectedTestException} on the tenth and every later group.
     */
    public static class MockFailingReduceStub extends RichGroupReduceFunction<Record, Record> {
        private static final long serialVersionUID = 1L;
        private int cnt = 0;
        private final IntValue key = new IntValue();
        private final IntValue value = new IntValue();

        @Override
        public void reduce(Iterable<Record> records, Collector<Record> out) {
            Record element = null;
            int valCnt = 0;
            for (Record next : records) {
                element = next;
                valCnt++;
            }
            if (++cnt >= 10) {
                throw new ExpectedTestException();
            }
            element.getField(0, key);
            value.setValue(valCnt - key.getValue());
            element.setField(1, value);
            out.collect(element);
        }
    }

    /** Sleeps 100 ms per input record and emits nothing. */
    public static class MockDelayingReduceStub extends RichGroupReduceFunction<Record, Record> {
        private static final long serialVersionUID = 1L;

        @Override
        public void reduce(Iterable<Record> records, Collector<Record> out) {
            for (Record unused : records) {
                try {
                    Thread.sleep(100L);
                } catch (InterruptedException ignored) {
                    // NOTE(review): interruption is deliberately swallowed here, presumably so
                    // that cancellation is handled by the driver rather than surfacing from the
                    // user function; confirm before changing.
                }
            }
        }
    }
}

