/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.clustering;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.clustering.KMeansSuite;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Assert;
import org.junit.Test;

/**
 * Java-API smoke test for {@link KMeans}: fits a model on generated data and checks the number of
 * cluster centers and the columns produced by {@link KMeansModel#transform}.
 */
public class JavaKMeansSuite
extends SharedSparkSession {
    /** Number of clusters requested from KMeans and passed to the data generator. */
    private final transient int k = 5;
    /** Input data, regenerated in {@link #setUp()} before each test. */
    private transient Dataset<Row> dataset;

    /**
     * Initializes the shared Spark session, then builds the input dataset.
     *
     * @throws IOException if the parent {@code setUp} fails
     */
    @Override
    public void setUp() throws IOException {
        super.setUp();
        // NOTE(review): 50 and 3 are presumably the row count and feature dimension expected by
        // KMeansSuite.generateKMeansData -- confirm against that helper's signature.
        this.dataset = KMeansSuite.generateKMeansData(this.spark, 50, 3, this.k);
    }

    /**
     * Fits KMeans with a fixed seed, asserts that exactly {@code k} centers come back, and asserts
     * that the transformed output contains the "features" and "prediction" columns.
     */
    @Test
    public void fitAndTransform() {
        KMeans kmeans = new KMeans().setK(this.k).setSeed(1L);
        KMeansModel model = kmeans.fit(this.dataset);

        Vector[] centers = model.clusterCenters();
        Assert.assertEquals("unexpected number of cluster centers", this.k, centers.length);

        // transform(...) returns Dataset<Row>; keep the type parameter instead of a raw Dataset.
        Dataset<Row> transformed = model.transform(this.dataset);
        List<String> columns = Arrays.asList(transformed.columns());
        List<String> expectedColumns = Arrays.asList("features", "prediction");
        for (String column : expectedColumns) {
            // Report the missing column and what was actually present, so failures are diagnosable.
            Assert.assertTrue("missing column '" + column + "' in " + columns,
                    columns.contains(column));
        }
    }
}

