/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.operators;

import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.runtime.operators.Driver;
import org.apache.flink.runtime.operators.FlatMapDriver;
import org.apache.flink.runtime.operators.testutils.DiscardingOutputCollector;
import org.apache.flink.runtime.operators.testutils.DriverTestBase;
import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
import org.apache.flink.runtime.operators.testutils.InfiniteInputIterator;
import org.apache.flink.runtime.operators.testutils.TaskCancelThread;
import org.apache.flink.runtime.operators.testutils.UniformRecordGenerator;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
import org.assertj.core.api.AbstractIntegerAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.AtomicBooleanAssert;
import org.junit.jupiter.api.TestTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class FlatMapTaskTest
extends DriverTestBase<FlatMapFunction<Record, Record>> {
    private static final Logger LOG = LoggerFactory.getLogger(FlatMapTaskTest.class);
    private final DriverTestBase.CountingOutputCollector output = new DriverTestBase.CountingOutputCollector();

    public FlatMapTaskTest(ExecutionConfig config) {
        super(config, 0L, 0);
    }

    @TestTemplate
    void testMapTask() {
        int keyCnt = 100;
        int valCnt = 20;
        this.addInput(new UniformRecordGenerator(100, 20, false));
        this.setOutput(this.output);
        FlatMapDriver testDriver = new FlatMapDriver();
        try {
            this.testDriver((Driver)testDriver, MockMapStub.class);
        }
        catch (Exception e) {
            LOG.debug("Exception while running the test driver.", (Throwable)e);
            Assertions.fail((String)"Invoke method caused exception.");
        }
        ((AbstractIntegerAssert)Assertions.assertThat((int)this.output.getNumberOfRecords()).withFailMessage("Wrong result set size.", new Object[0])).isEqualTo(2000);
    }

    @TestTemplate
    void testFailingMapTask() {
        int keyCnt = 100;
        int valCnt = 20;
        this.addInput(new UniformRecordGenerator(100, 20, false));
        this.setOutput(new DiscardingOutputCollector<Record>());
        FlatMapDriver testTask = new FlatMapDriver();
        Assertions.assertThatThrownBy(() -> this.testDriver((Driver)testTask, MockFailingMapStub.class)).isInstanceOf(ExpectedTestException.class);
    }

    @TestTemplate
    void testCancelMapTask() {
        this.addInput(new InfiniteInputIterator());
        this.setOutput(new DiscardingOutputCollector<Record>());
        final FlatMapDriver testTask = new FlatMapDriver();
        final AtomicBoolean success = new AtomicBoolean(false);
        Thread taskRunner = new Thread(){

            @Override
            public void run() {
                try {
                    FlatMapTaskTest.this.testDriver((Driver)testTask, MockMapStub.class);
                    success.set(true);
                }
                catch (Exception ie) {
                    ie.printStackTrace();
                }
            }
        };
        taskRunner.start();
        TaskCancelThread tct = new TaskCancelThread(1, taskRunner, this);
        tct.start();
        try {
            tct.join();
            taskRunner.join();
        }
        catch (InterruptedException ie) {
            Assertions.fail((String)"Joining threads failed");
        }
        ((AtomicBooleanAssert)Assertions.assertThat((AtomicBoolean)success).withFailMessage("Test threw an exception even though it was properly canceled.", new Object[0])).isTrue();
    }

    public static class MockMapStub
    extends RichFlatMapFunction<Record, Record> {
        private static final long serialVersionUID = 1L;

        public void flatMap(Record record, Collector<Record> out) throws Exception {
            out.collect((Object)record);
        }
    }

    public static class MockFailingMapStub
    extends RichFlatMapFunction<Record, Record> {
        private static final long serialVersionUID = 1L;
        private int cnt = 0;

        public void flatMap(Record record, Collector<Record> out) throws Exception {
            if (++this.cnt >= 10) {
                throw new ExpectedTestException();
            }
            out.collect((Object)record);
        }
    }
}

