/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.plan;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import org.apache.hadoop.hive.ql.plan.SparkWork;
import org.apache.hadoop.hive.ql.plan.TezWork;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(value=Parameterized.class)
public final class TestExecutionEngineWorkConcurrency {
    private final ExecutionEngineDagIdGenerator executionEngineDagIdGenerator;

    @Parameterized.Parameters
    public static Collection<Object[]> data() {
        return Arrays.asList({new TezDagIdProvider()}, {new SparkDagIdProvider()});
    }

    public TestExecutionEngineWorkConcurrency(ExecutionEngineDagIdGenerator executionEngineDagIdGenerator) {
        this.executionEngineDagIdGenerator = executionEngineDagIdGenerator;
    }

    @Test
    public void ensureDagIdIsUnique() throws Exception {
        int threadCount = 5;
        final CountDownLatch threadReadyToStartSignal = new CountDownLatch(5);
        final CountDownLatch startThreadSignal = new CountDownLatch(1);
        int numberOfWorkToCreatePerThread = 100;
        ArrayList tasks = Lists.newArrayList();
        for (int i = 0; i < 5; ++i) {
            tasks.add(new FutureTask<Set<String>>(new Callable<Set<String>>(){

                @Override
                public Set<String> call() throws Exception {
                    threadReadyToStartSignal.countDown();
                    startThreadSignal.await();
                    return TestExecutionEngineWorkConcurrency.this.generateWorkDagIds(100);
                }
            }));
        }
        ExecutorService executor = Executors.newFixedThreadPool(5);
        for (FutureTask task : tasks) {
            executor.execute(task);
        }
        threadReadyToStartSignal.await();
        startThreadSignal.countDown();
        Set<String> allWorkDagIds = TestExecutionEngineWorkConcurrency.getAllWorkDagIds(tasks);
        Assert.assertEquals((long)500L, (long)allWorkDagIds.size());
    }

    private Set<String> generateWorkDagIds(int numberOfNames) {
        HashSet workIds = Sets.newHashSet();
        for (int i = 0; i < numberOfNames; ++i) {
            workIds.add(this.executionEngineDagIdGenerator.getDagId());
        }
        return workIds;
    }

    private static Set<String> getAllWorkDagIds(List<FutureTask<Set<String>>> tasks) throws ExecutionException, InterruptedException {
        HashSet allWorkDagIds = Sets.newHashSet();
        for (FutureTask<Set<String>> task : tasks) {
            allWorkDagIds.addAll((Collection)task.get());
        }
        return allWorkDagIds;
    }

    private static final class SparkDagIdProvider
    implements ExecutionEngineDagIdGenerator {
        private SparkDagIdProvider() {
        }

        @Override
        public String getDagId() {
            return new SparkWork("query-id").getName();
        }
    }

    private static final class TezDagIdProvider
    implements ExecutionEngineDagIdGenerator {
        private TezDagIdProvider() {
        }

        @Override
        public String getDagId() {
            return new TezWork("query-id").getDagId();
        }
    }

    private static interface ExecutionEngineDagIdGenerator {
        public String getDagId();
    }
}

