package org.apache.spark.ml.clustering;

import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Assert;
import org.junit.Test;

/**
 * Java-API test for {@link KMeans}: fits a model on generated data and checks the shape of the
 * fitted model and of its transformed output.
 */
public class JavaKMeansSuite extends SharedSparkSession {
    /** Number of clusters requested from KMeans and expected in the fitted model. */
    private final transient int k = 5;
    private transient Dataset<Row> dataset;

    @Override // org.apache.spark.SharedSparkSession
    public void setUp() throws IOException {
        super.setUp();
        // NOTE(review): the arguments are presumably (rows = 50, dimensions = 3, clusters = k);
        // confirm against KMeansSuite.generateKMeansData.
        this.dataset = KMeansSuite.generateKMeansData(this.spark, 50, 3, this.k);
    }

    /**
     * Fits KMeans with a fixed seed. Asserts that the model has exactly {@code k} cluster
     * centers and that the transformed dataset exposes both the "features" and "prediction"
     * columns.
     */
    @Test
    public void fitAndTransform() {
        KMeansModel model = new KMeans().setK(this.k).setSeed(1L).fit(this.dataset);
        Assert.assertEquals(this.k, model.clusterCenters().length);

        List<String> columns = Arrays.asList(model.transform(this.dataset).columns());
        List<String> expectedColumns = Arrays.asList("features", "prediction");
        for (String column : expectedColumns) {
            Assert.assertTrue("Missing output column: " + column, columns.contains(column));
        }
    }
}
