package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import jersey.repackaged.com.google.common.collect.Lists;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBinarySetFunctions;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.LongWritable;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Parameterized tests for the binary-set aggregate functions (regr_*, corr, covar_pop, covar_samp).
 *
 * <p>Each UDAF is evaluated over a generated row set through three pipelines (COMPLETE,
 * PARTIAL1 -> FINAL and PARTIAL1 -> PARTIAL2 -> FINAL), and every result is compared with a
 * straightforward reference computation performed by {@link RegrIntermediate}.
 */
@RunWith(Parameterized.class)
public class TestGenericUDAFBinarySetFunctions {

    /** Absolute tolerance used when comparing UDAF results with the reference values. */
    private static final double DELTA = 1.0E-10d;

    /** Labels for the results returned by {@link GenericUDAFExecutor#run}, in the same order. */
    private static final String[] EXECUTION_MODES = {
        "COMPLETE", "PARTIAL1->FINAL", "PARTIAL1->PARTIAL2->FINAL"
    };

    /** Rows for the current parameterized run; {@link RegrIntermediate} reads each as {y, x}. */
    private final List<Object[]> rowSet;

    /**
     * Drives a {@link GenericUDAFResolver2} through several evaluation-mode pipelines so that the
     * same input can be aggregated in one step and via partial aggregations.
     */
    public static class GenericUDAFExecutor {
        private final GenericUDAFResolver2 evaluatorFactory;
        private final GenericUDAFParameterInfo info;
        /** Inspector(s) for the partial aggregation produced by a PARTIAL1-mode evaluator. */
        private final ObjectInspector[] partialOIs;

        public GenericUDAFExecutor(GenericUDAFResolver2 evaluatorFactory, GenericUDAFParameterInfo info)
                throws Exception {
            this.evaluatorFactory = evaluatorFactory;
            this.info = info;
            ObjectInspector partialOI = evaluatorFactory.getEvaluator(info)
                    .init(GenericUDAFEvaluator.Mode.PARTIAL1, info.getParameterObjectInspectors());
            this.partialOIs = new ObjectInspector[] {partialOI};
        }

        /**
         * Aggregates {@code rows} through every pipeline.
         *
         * @param rows the input rows, each holding one value per UDAF argument
         * @return the final results, in order: COMPLETE, PARTIAL1->FINAL, PARTIAL1->PARTIAL2->FINAL
         */
        List<Object> run(List<Object[]> rows) throws Exception {
            List<Object> results = new ArrayList<>(3);
            results.add(runComplete(rows));
            results.add(runPartialFinal(rows));
            results.add(runPartial2Final(rows));
            return results;
        }

        /** Aggregates all rows with a single COMPLETE-mode evaluator. */
        private Object runComplete(List<Object[]> rows) throws SemanticException, HiveException {
            GenericUDAFEvaluator evaluator = evaluatorFactory.getEvaluator(info);
            evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, info.getParameterObjectInspectors());
            GenericUDAFEvaluator.AggregationBuffer buffer = evaluator.getNewAggregationBuffer();
            for (Object[] row : rows) {
                evaluator.iterate(buffer, row);
            }
            return evaluator.terminate(buffer);
        }

        /** Produces PARTIAL1 aggregations and merges them in FINAL mode. */
        private Object runPartialFinal(List<Object[]> rows) throws Exception {
            return runFinal(runPartial1(rows));
        }

        /** Produces PARTIAL1 aggregations, re-merges them in PARTIAL2 mode, then finishes in FINAL mode. */
        private Object runPartial2Final(List<Object[]> rows) throws Exception {
            return runFinal(runPartial2(runPartial1(rows)));
        }

        /** Merges the given partial aggregations with a fresh FINAL-mode evaluator and returns its result. */
        private Object runFinal(List<Object> partials) throws Exception {
            GenericUDAFEvaluator evaluator = evaluatorFactory.getEvaluator(info);
            evaluator.init(GenericUDAFEvaluator.Mode.FINAL, partialOIs);
            GenericUDAFEvaluator.AggregationBuffer buffer = evaluator.getNewAggregationBuffer();
            for (Object partial : partials) {
                evaluator.merge(buffer, partial);
            }
            return evaluator.terminate(buffer);
        }

        /**
         * Splits {@code rows} into batches of 0, 1, 3, 7, ... (2^k - 1) rows and aggregates each batch
         * with its own PARTIAL1-mode evaluator. The first batch is always empty, so merging an empty
         * partial aggregation is always exercised; at least one partial is produced even for no rows.
         */
        private List<Object> runPartial1(List<Object[]> rows) throws Exception {
            List<Object> partials = new ArrayList<>();
            int batchSize = 1;
            Iterator<Object[]> iter = rows.iterator();
            do {
                GenericUDAFEvaluator evaluator = evaluatorFactory.getEvaluator(info);
                evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, info.getParameterObjectInspectors());
                GenericUDAFEvaluator.AggregationBuffer buffer = evaluator.getNewAggregationBuffer();
                for (int i = 0; i < batchSize - 1 && iter.hasNext(); i++) {
                    evaluator.iterate(buffer, iter.next());
                }
                batchSize <<= 1;
                partials.add(evaluator.terminatePartial(buffer));
            } while (iter.hasNext());
            return partials;
        }

        /**
         * Re-groups partial aggregations into batches of 0, 1, 3, 7, ... (2^k - 1) and merges each batch
         * with its own PARTIAL2-mode evaluator, mirroring the batching of {@link #runPartial1}.
         */
        private List<Object> runPartial2(List<Object> partials) throws Exception {
            List<Object> merged = new ArrayList<>();
            int batchSize = 1;
            Iterator<Object> iter = partials.iterator();
            do {
                GenericUDAFEvaluator evaluator = evaluatorFactory.getEvaluator(info);
                evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOIs);
                GenericUDAFEvaluator.AggregationBuffer buffer = evaluator.getNewAggregationBuffer();
                for (int i = 0; i < batchSize - 1 && iter.hasNext(); i++) {
                    evaluator.merge(buffer, iter.next());
                }
                batchSize <<= 1;
                merged.add(evaluator.terminatePartial(buffer));
            } while (iter.hasNext());
            return merged;
        }
    }

    /**
     * Reference implementation: accumulates raw sums over (y, x) pairs and derives each statistic
     * directly from them. Pairs in which either value is null are ignored.
     */
    static class RegrIntermediate {
        public double sum_x2;
        public double sum_y2;
        public double sum_x;
        public double sum_y;
        public double sum_xy;
        public double n;

        /**
         * Adds one observation.
         *
         * @param y the dependent value (first UDAF argument); ignored together with x when null
         * @param x the independent value (second UDAF argument); ignored together with y when null
         */
        public void add(Double y, Double x) {
            if (x == null || y == null) {
                return;
            }
            double xv = x.doubleValue();
            double yv = y.doubleValue();
            sum_x2 += xv * xv;
            sum_y2 += yv * yv;
            sum_x += xv;
            sum_y += yv;
            sum_xy += xv * yv;
            n += 1.0d;
        }

        /** n * Sxx, i.e. n * sum(x^2) - sum(x)^2. */
        private double nSxx() {
            return n * sum_x2 - sum_x * sum_x;
        }

        /** n * Syy, i.e. n * sum(y^2) - sum(y)^2. */
        private double nSyy() {
            return n * sum_y2 - sum_y * sum_y;
        }

        /** n * Sxy, i.e. n * sum(x*y) - sum(x) * sum(y). */
        private double nSxy() {
            return n * sum_xy - sum_x * sum_y;
        }

        /** Least-squares intercept; null when there are no rows or x is constant. */
        public Double intercept() {
            double xx = nSxx();
            if (n == 0.0d || xx == 0.0d) {
                return null;
            }
            return Double.valueOf((sum_y * sum_x2 - sum_x * sum_xy) / xx);
        }

        /** Sum of products of deviations; null when there are no rows. */
        public Double sxy() {
            if (n == 0.0d) {
                return null;
            }
            return Double.valueOf(sum_xy - (sum_x * sum_y) / n);
        }

        /** Population covariance; null when there are no rows. */
        public Double covar_pop() {
            if (n == 0.0d) {
                return null;
            }
            return Double.valueOf((sum_xy - (sum_x * sum_y) / n) / n);
        }

        /** Sample covariance; null when there are fewer than two rows. */
        public Double covar_samp() {
            if (n <= 1.0d) {
                return null;
            }
            return Double.valueOf((sum_xy - (sum_x * sum_y) / n) / (n - 1.0d));
        }

        /**
         * Pearson correlation coefficient, keeping the sign of the covariance (the previous
         * sqrt(c^2 / xx / yy) form always returned the absolute value). Null when there are no rows
         * or either variable is constant.
         */
        public Double corr() {
            double xx = nSxx();
            double yy = nSyy();
            if (n == 0.0d || xx == 0.0d || yy == 0.0d) {
                return null;
            }
            return Double.valueOf(nSxy() / (Math.sqrt(xx) * Math.sqrt(yy)));
        }

        /** Coefficient of determination; null when x is constant (or no rows), 1 when y is constant. */
        public Double r2() {
            double xx = nSxx();
            double yy = nSyy();
            if (n == 0.0d || xx == 0.0d) {
                return null;
            }
            if (yy == 0.0d) {
                return Double.valueOf(1.0d);
            }
            double c = nSxy();
            return Double.valueOf(c * c / xx / yy);
        }

        /** Least-squares slope; null when there are no rows or x is constant. */
        public Double slope() {
            double xx = nSxx();
            if (n == 0.0d || xx == 0.0d) {
                return null;
            }
            return Double.valueOf(nSxy() / xx);
        }

        /** Mean of x; null when there are no rows. */
        public Double avgx() {
            if (n == 0.0d) {
                return null;
            }
            return Double.valueOf(sum_x / n);
        }

        /** Mean of y; null when there are no rows. */
        public Double avgy() {
            if (n == 0.0d) {
                return null;
            }
            return Double.valueOf(sum_y / n);
        }

        /** Number of non-null pairs (0 for an empty input). */
        public Double count() {
            return Double.valueOf(n);
        }

        /** Sum of squared deviations of x; null when there are no rows. */
        public Double sxx() {
            if (n == 0.0d) {
                return null;
            }
            return Double.valueOf(sum_x2 - (sum_x * sum_x) / n);
        }

        /** Sum of squared deviations of y; null when there are no rows. */
        public Double syy() {
            if (n == 0.0d) {
                return null;
            }
            return Double.valueOf(sum_y2 - (sum_y * sum_y) / n);
        }

        /**
         * Builds the reference sums for {@code rows}.
         *
         * @param rows rows whose element 0 is y and element 1 is x, each a {@code Double} or null
         */
        public static RegrIntermediate computeFor(List<Object[]> rows) {
            RegrIntermediate result = new RegrIntermediate();
            for (Object[] row : rows) {
                result.add((Double) row[0], (Double) row[1]);
            }
            return result;
        }
    }

    /** Builds synthetic row sets from per-column value generators. */
    public static class RowSetGenerator {

        /** Produces the value of one column for a given row index. */
        public interface FieldGenerator {
            Object apply(int i);
        }

        /** Returns the same value (possibly null) for every row. */
        public static class ConstantSequence implements FieldGenerator {
            private final Object constant;

            public ConstantSequence(Object constant) {
                this.constant = constant;
            }

            @Override
            public Object apply(int i) {
                return constant;
            }
        }

        /** Returns {@code rowIndex + offset} as a {@code Double}. */
        public static class DoubleSequence implements FieldGenerator {
            private final int offset;

            public DoubleSequence(int offset) {
                this.offset = offset;
            }

            @Override
            public Object apply(int i) {
                return Double.valueOf(i + offset);
            }
        }

        /**
         * Generates {@code rowCount} rows; column j of row i is {@code generators[j].apply(i)}.
         *
         * @param rowCount   number of rows to produce
         * @param generators one generator per column
         * @return a mutable list of rows
         */
        public static List<Object[]> generate(int rowCount, FieldGenerator... generators) {
            List<Object[]> rows = new ArrayList<>(rowCount);
            for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
                Object[] row = new Object[generators.length];
                for (int col = 0; col < generators.length; col++) {
                    row[col] = generators[col].apply(rowIndex);
                }
                rows.add(row);
            }
            return rows;
        }
    }

    /**
     * Test parameters: a display name (used as the test name via {@code {0}}) and the row set,
     * whose first column is y and second column is x.
     */
    @Parameterized.Parameters(name = "{0}")
    public static List<Object[]> getParameters() {
        List<Object[]> params = new ArrayList<>();
        params.add(new Object[] {"seq/seq", RowSetGenerator.generate(10,
                new RowSetGenerator.DoubleSequence(0), new RowSetGenerator.DoubleSequence(0))});
        params.add(new Object[] {"seq/ones", RowSetGenerator.generate(10,
                new RowSetGenerator.DoubleSequence(0), new RowSetGenerator.ConstantSequence(Double.valueOf(1.0d)))});
        params.add(new Object[] {"ones/seq", RowSetGenerator.generate(10,
                new RowSetGenerator.ConstantSequence(Double.valueOf(1.0d)), new RowSetGenerator.DoubleSequence(0))});
        params.add(new Object[] {"empty", RowSetGenerator.generate(0,
                new RowSetGenerator.DoubleSequence(0), new RowSetGenerator.DoubleSequence(0))});
        params.add(new Object[] {"lonely", RowSetGenerator.generate(1,
                new RowSetGenerator.DoubleSequence(10), new RowSetGenerator.DoubleSequence(10))});
        params.add(new Object[] {"seq/seq+10", RowSetGenerator.generate(10,
                new RowSetGenerator.DoubleSequence(0), new RowSetGenerator.DoubleSequence(10))});
        params.add(new Object[] {"seq/null", RowSetGenerator.generate(10,
                new RowSetGenerator.DoubleSequence(0), new RowSetGenerator.ConstantSequence(null))});
        params.add(new Object[] {"null/seq0", RowSetGenerator.generate(10,
                new RowSetGenerator.ConstantSequence(null), new RowSetGenerator.DoubleSequence(0))});
        return params;
    }

    /**
     * @param name   display name of the parameter set (only used by the runner for naming)
     * @param rowSet the rows to aggregate
     */
    public TestGenericUDAFBinarySetFunctions(String name, List<Object[]> rowSet) {
        this.rowSet = rowSet;
    }

    @Test
    public void regr_count() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).count(), new GenericUDAFBinarySetFunctions.RegrCount());
    }

    @Test
    public void regr_sxx() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).sxx(), new GenericUDAFBinarySetFunctions.RegrSXX());
    }

    @Test
    public void regr_syy() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).syy(), new GenericUDAFBinarySetFunctions.RegrSYY());
    }

    @Test
    public void regr_sxy() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).sxy(), new GenericUDAFBinarySetFunctions.RegrSXY());
    }

    @Test
    public void regr_avgx() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).avgx(), new GenericUDAFBinarySetFunctions.RegrAvgX());
    }

    @Test
    public void regr_avgy() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).avgy(), new GenericUDAFBinarySetFunctions.RegrAvgY());
    }

    @Test
    public void regr_slope() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).slope(), new GenericUDAFBinarySetFunctions.RegrSlope());
    }

    @Test
    public void regr_r2() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).r2(), new GenericUDAFBinarySetFunctions.RegrR2());
    }

    @Test
    public void regr_intercept() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).intercept(), new GenericUDAFBinarySetFunctions.RegrIntercept());
    }

    @Test
    public void corr() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).corr(), new GenericUDAFCorrelation());
    }

    @Test
    public void covar_pop() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).covar_pop(), new GenericUDAFCovariance());
    }

    @Test
    public void covar_samp() throws Exception {
        validateUDAF(RegrIntermediate.computeFor(rowSet).covar_samp(), new GenericUDAFCovarianceSample());
    }

    /**
     * Runs {@code udaf} over the current row set (two double arguments) in every pipeline and checks
     * each result against {@code expected}.
     *
     * @param expected the reference value, or null if every pipeline must return null
     * @param udaf     the resolver under test
     */
    private void validateUDAF(Double expected, GenericUDAFResolver2 udaf) throws Exception {
        ObjectInspector[] argumentOIs = {
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
        };
        GenericUDAFParameterInfo info = new SimpleGenericUDAFParameterInfo(argumentOIs, false, false, false);
        List<Object> results = new GenericUDAFExecutor(udaf, info).run(rowSet);
        for (int i = 0; i < results.size(); i++) {
            String mode = i < EXECUTION_MODES.length ? EXECUTION_MODES[i] : "result #" + i;
            Object actual = results.get(i);
            if (expected == null) {
                Assert.assertNull(mode, actual);
                continue;
            }
            // Fail with a clear message instead of an NPE when a value was expected.
            Assert.assertNotNull(mode, actual);
            double actualValue;
            if (actual instanceof DoubleWritable) {
                actualValue = ((DoubleWritable) actual).get();
            } else if (actual instanceof LongWritable) {
                actualValue = ((LongWritable) actual).get();
            } else {
                throw new AssertionError(mode + ": unexpected result type " + actual.getClass().getName());
            }
            Assert.assertEquals(mode, expected.doubleValue(), actualValue, DELTA);
        }
    }
}
