/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.stat;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.stat.Summarizer;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;
import org.junit.Assert;
import org.junit.Test;

/**
 * Exercises the {@link Summarizer} API from Java: requests several metrics over a vector
 * column and checks the values in the resulting struct.
 */
public class JavaSummarizerSuite
extends SharedSparkSession {
    private transient Dataset<Row> dataset;

    /**
     * Builds a DataFrame of two {@link LabeledPoint}s, both with label 0.0, whose feature
     * vectors are [1.0, 2.0] and [3.0, 4.0], parallelized into 2 partitions.
     *
     * @throws IOException if the shared session setup fails
     */
    @Override
    public void setUp() throws IOException {
        super.setUp();
        List<LabeledPoint> points = new ArrayList<>();
        points.add(new LabeledPoint(0.0, Vectors.dense(1.0, 2.0)));
        points.add(new LabeledPoint(0.0, Vectors.dense(3.0, 4.0)));
        this.dataset = this.spark.createDataFrame(this.jsc.parallelize(points, 2), LabeledPoint.class);
    }

    /**
     * Computes mean, max and count of the "features" column in one aggregation and checks them
     * element-wise against the values implied by the two points built in {@link #setUp()}.
     */
    @Test
    public void testSummarizer() {
        Column features = functions.col("features");
        // The summary is a single struct column; its fields are the requested metrics.
        Row result = this.dataset
            .select(Summarizer.metrics("mean", "max", "count").summary(features))
            .first()
            .getStruct(0);

        Vector meanVec = result.getAs("mean");
        Vector maxVec = result.getAs("max");
        long count = result.<Long>getAs("count");

        Assert.assertEquals(2L, count);
        // Inputs are small integers, so exact (0.0 tolerance) comparison is safe here.
        Assert.assertArrayEquals(new double[]{2.0, 3.0}, meanVec.toArray(), 0.0);
        Assert.assertArrayEquals(new double[]{3.0, 4.0}, maxVec.toArray(), 0.0);
    }
}

