package org.apache.spark.ml.evaluation;

import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.evaluation.SquaredEuclideanSilhouette;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.util.MetadataUtils$;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DoubleType$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple6;
import scala.collection.immutable.Map;
import scala.collection.immutable.MapLike;
import scala.math.Ordering$Double$;
import scala.reflect.ClassTag$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;

/* compiled from: ClusteringMetrics.scala */
/* loaded from: input_file:org/apache/spark/ml/evaluation/SquaredEuclideanSilhouette$.class */
/**
 * Singleton that computes the Silhouette score when the distance measure is squared Euclidean.
 *
 * <p>Java view of the Scala object {@code SquaredEuclideanSilhouette} (ClusteringMetrics.scala).
 *
 * <p>Per-cluster aggregates ({@link SquaredEuclideanSilhouette.ClusterStats}) are computed once with
 * an RDD {@code aggregateByKey}. The aggregates are then broadcast. Each point's mean squared distance
 * to every cluster is evaluated inside a UDF from those aggregates.
 *
 * <p>NOTE(review): this text appears to be decompiler output. A Java lambda or method reference is
 * serializable only if its target interface is. The RDD and UDF closures below target
 * {@code scala.FunctionN}. Confirm they serialize before building this file as Java.
 */
public final class SquaredEuclideanSilhouette$ extends Silhouette {
    /** The singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static SquaredEuclideanSilhouette$ MODULE$;

    /** Set after {@link #registerKryoClasses} has run once on this singleton (not synchronized). */
    private boolean kryoRegistrationPerformed;

    static {
        new SquaredEuclideanSilhouette$();
    }

    /**
     * Registers {@code ClusterStats} with Kryo on the configuration obtained from the context.
     *
     * <p>Only the first call on this singleton does anything. Later calls return immediately.
     *
     * <p>NOTE(review): if {@code SparkContext.getConf()} returns a copy of the live configuration,
     * this registration may never reach the running context. Confirm against the Spark version in use.
     *
     * @param sparkContext context whose configuration receives the registration
     */
    public void registerKryoClasses(SparkContext sparkContext) {
        if (this.kryoRegistrationPerformed) {
            return;
        }
        sparkContext.getConf().registerKryoClasses(new Class[]{SquaredEuclideanSilhouette.ClusterStats.class});
        this.kryoRegistrationPerformed = true;
    }

    /**
     * Aggregates, for every cluster id, the points' weighted feature sum, weighted sum of squared
     * norms, and total weight.
     *
     * @param dataset input rows; must already contain a double {@code "squaredNorm"} column
     * @param str     prediction (cluster id) column; cast to double and used as the map key
     * @param str2    features column; also passed to {@code MetadataUtils.getNumFeatures} to size the
     *                dense accumulators
     * @param str3    weight column (double)
     * @return immutable map from boxed cluster id to that cluster's {@code ClusterStats}
     * @throws MatchError if a row's features vector is null
     */
    public Map<Object, SquaredEuclideanSilhouette.ClusterStats> computeClusterStats(Dataset<Row> dataset, String str, String str2, String str3) {
        int numFeatures = MetadataUtils$.MODULE$.getNumFeatures(dataset, str2);
        Dataset<Row> selected = dataset.select(Predef$.MODULE$.wrapRefArray(new Column[]{
            functions$.MODULE$.col(str).cast(DoubleType$.MODULE$),
            functions$.MODULE$.col(str2),
            functions$.MODULE$.col("squaredNorm"),
            functions$.MODULE$.col(str3)}));

        // The zero value has a null feature sum. addPoint allocates a fresh dense zero vector of
        // numFeatures on first use, so the zero value itself never has to carry one.
        Tuple3 zero = new Tuple3((Object) null, BoxesRunTime.boxToDouble(0.0d), BoxesRunTime.boxToDouble(0.0d));

        return RDD$.MODULE$.rddToPairRDDFunctions(
                RDD$.MODULE$.rddToPairRDDFunctions(
                        selected.rdd().map(SquaredEuclideanSilhouette$::toKeyedPoint, ClassTag$.MODULE$.apply(Tuple2.class)),
                        ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.apply(Tuple3.class), Ordering$Double$.MODULE$)
                    .aggregateByKey(
                        zero,
                        (acc, point) -> addPoint((Tuple3) acc, (Tuple3) point, numFeatures),
                        (left, right) -> mergeStats((Tuple3) left, (Tuple3) right),
                        ClassTag$.MODULE$.apply(Tuple3.class)),
                ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.apply(Tuple3.class), Ordering$Double$.MODULE$)
            .collectAsMap()
            .mapValues(stats -> toClusterStats((Tuple3) stats))
            .toMap(Predef$.MODULE$.$conforms());
    }

    /** Maps a selected row to {@code (clusterId, (features, squaredNorm, weight))}. */
    private static Tuple2 toKeyedPoint(Row row) {
        return new Tuple2(
            BoxesRunTime.boxToDouble(row.getDouble(0)),
            new Tuple3(row.getAs(1), BoxesRunTime.boxToDouble(row.getDouble(2)), BoxesRunTime.boxToDouble(row.getDouble(3))));
    }

    /**
     * seqOp: folds one point into a cluster accumulator.
     *
     * <p>The accumulator's feature sum is updated in place. It is first allocated as a zero vector of
     * {@code numFeatures} if it is still null.
     *
     * @throws MatchError if either tuple or the point's feature vector is null
     */
    private static Tuple3 addPoint(Tuple3 acc, Tuple3 point, int numFeatures) {
        if (acc == null || point == null) {
            throw new MatchError(new Tuple2(acc, point));
        }
        DenseVector featureSum = (DenseVector) acc._1();
        double squaredNormSum = BoxesRunTime.unboxToDouble(acc._2());
        double weightSum = BoxesRunTime.unboxToDouble(acc._3());
        Vector features = (Vector) point._1();
        double squaredNorm = BoxesRunTime.unboxToDouble(point._2());
        double weight = BoxesRunTime.unboxToDouble(point._3());
        if (features == null) {
            throw new MatchError(new Tuple2(acc, point));
        }
        DenseVector sum = featureSum == null ? Vectors$.MODULE$.zeros(numFeatures).toDense() : featureSum;
        // sum += weight * features
        BLAS$.MODULE$.axpy(weight, features, sum);
        return new Tuple3(sum, BoxesRunTime.boxToDouble(squaredNormSum + (squaredNorm * weight)), BoxesRunTime.boxToDouble(weightSum + weight));
    }

    /**
     * combOp: merges two partial cluster accumulators.
     *
     * <p>A null feature sum means "no points yet" and yields the other side's sum. Otherwise the
     * right sum is added into the left one in place.
     *
     * @throws MatchError if either tuple is null
     */
    private static Tuple3 mergeStats(Tuple3 left, Tuple3 right) {
        if (left == null || right == null) {
            throw new MatchError(new Tuple2(left, right));
        }
        DenseVector leftSum = (DenseVector) left._1();
        double leftSquaredNormSum = BoxesRunTime.unboxToDouble(left._2());
        double leftWeightSum = BoxesRunTime.unboxToDouble(left._3());
        DenseVector rightSum = (DenseVector) right._1();
        double rightSquaredNormSum = BoxesRunTime.unboxToDouble(right._2());
        double rightWeightSum = BoxesRunTime.unboxToDouble(right._3());
        DenseVector merged;
        if (leftSum == null) {
            merged = rightSum;
        } else if (rightSum == null) {
            merged = leftSum;
        } else {
            BLAS$.MODULE$.axpy(1.0d, rightSum, leftSum);
            merged = leftSum;
        }
        return new Tuple3(merged, BoxesRunTime.boxToDouble(leftSquaredNormSum + rightSquaredNormSum), BoxesRunTime.boxToDouble(leftWeightSum + rightWeightSum));
    }

    /**
     * Converts a final accumulator into {@code ClusterStats}.
     *
     * @throws MatchError if the tuple or its feature sum is null
     */
    private static SquaredEuclideanSilhouette.ClusterStats toClusterStats(Tuple3 stats) {
        if (stats != null) {
            DenseVector featureSum = (DenseVector) stats._1();
            double squaredNormSum = BoxesRunTime.unboxToDouble(stats._2());
            double weightSum = BoxesRunTime.unboxToDouble(stats._3());
            if (featureSum != null) {
                return new SquaredEuclideanSilhouette.ClusterStats(featureSum, squaredNormSum, weightSum);
            }
        }
        throw new MatchError(stats);
    }

    /**
     * Silhouette coefficient of a single point, evaluated from the broadcast per-cluster statistics.
     *
     * @param broadcast broadcast map from boxed cluster id to {@code ClusterStats}
     * @param vector    the point's features
     * @param d         the point's cluster id; looked up in the broadcast map with {@code apply}
     * @param d2        the point's weight
     * @param d3        the point's squared Euclidean norm
     * @return the value computed by the inherited {@code pointSilhouetteCoefficient}
     */
    public double computeSilhouetteCoefficient(Broadcast<Map<Object, SquaredEuclideanSilhouette.ClusterStats>> broadcast, Vector vector, double d, double d2, double d3) {
        return pointSilhouetteCoefficient(
            ((MapLike) broadcast.value()).keySet(),
            d,
            ((SquaredEuclideanSilhouette.ClusterStats) ((scala.collection.MapLike) broadcast.value()).apply(BoxesRunTime.boxToDouble(d))).weightSum(),
            d2,
            d4 -> {
                return compute$1(d4, broadcast, vector, d3);
            });
    }

    /**
     * Computes the weighted (multi-cluster) Silhouette score under squared Euclidean distance.
     *
     * <p>The broadcast of cluster statistics is destroyed whether or not scoring succeeds.
     *
     * @param dataset clustered dataset
     * @param str     prediction (cluster id) column
     * @param str2    features column
     * @param str3    weight column
     * @return the score returned by the inherited {@code overallScore}
     * @throws AssertionError (via {@code Predef.assert}) if fewer than two clusters are present
     */
    public double computeSilhouetteScore(Dataset<?> dataset, String str, String str2, String str3) {
        SparkContext sparkContext = dataset.sparkSession().sparkContext();
        registerKryoClasses(sparkContext);
        // Append each point's squared L2 norm. Both computeClusterStats and the per-point UDF read it.
        Dataset<Row> withColumn = dataset.withColumn("squaredNorm", functions$.MODULE$.udf(vector -> {
            return BoxesRunTime.boxToDouble($anonfun$computeSilhouetteScore$1(vector));
        }, package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.evaluation.SquaredEuclideanSilhouette$$typecreator1$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                mirror.universe();
                return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }
        })).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(str2)})));
        Map<Object, SquaredEuclideanSilhouette.ClusterStats> computeClusterStats = computeClusterStats(withColumn, str, str2, str3);
        Predef$.MODULE$.assert(computeClusterStats.size() > 1, () -> {
            return "Number of clusters must be greater than one.";
        });
        Broadcast broadcast = sparkContext.broadcast(computeClusterStats, ClassTag$.MODULE$.apply(Map.class));
        try {
            // UDF argument order: features, cluster id (as double), weight, squared norm.
            return overallScore(withColumn, functions$.MODULE$.udf((vector2, obj, obj2, obj3) -> {
                return BoxesRunTime.boxToDouble($anonfun$computeSilhouetteScore$3(broadcast, vector2, BoxesRunTime.unboxToDouble(obj), BoxesRunTime.unboxToDouble(obj2), BoxesRunTime.unboxToDouble(obj3)));
            }, package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.evaluation.SquaredEuclideanSilhouette$$typecreator2$1
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    mirror.universe();
                    return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                }
            }), package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().Double()).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(str2), functions$.MODULE$.col(str).cast(DoubleType$.MODULE$), functions$.MODULE$.col(str3), functions$.MODULE$.col("squaredNorm")})), functions$.MODULE$.col(str3));
        } finally {
            // Previously skipped when scoring threw, which leaked the broadcast statistics.
            broadcast.destroy();
        }
    }

    /**
     * Weighted mean squared Euclidean distance from a point to every point of cluster {@code d}.
     *
     * <p>Let {@code W} be the cluster's {@code weightSum}. The cluster's {@code squaredNormSum} is
     * {@code sum(w_j * |y_j|^2)} and its {@code featureSum} is {@code sum(w_j * y_j)}. Then
     * {@code sum(w_j * |x - y_j|^2) / W = |x|^2 + squaredNormSum / W - 2 * dot(x, featureSum) / W}.
     *
     * @param d         target cluster id
     * @param broadcast broadcast map from boxed cluster id to {@code ClusterStats}
     * @param vector    the point {@code x}
     * @param d2        {@code |x|^2}
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final double compute$1(double d, Broadcast broadcast, Vector vector, double d2) {
        SquaredEuclideanSilhouette.ClusterStats clusterStats = (SquaredEuclideanSilhouette.ClusterStats) ((scala.collection.MapLike) broadcast.value()).apply(BoxesRunTime.boxToDouble(d));
        return (d2 + (clusterStats.squaredNormSum() / clusterStats.weightSum())) - ((2 * BLAS$.MODULE$.dot(vector, clusterStats.featureSum())) / clusterStats.weightSum());
    }

    /** Squared L2 norm of {@code vector}, used to fill the {@code "squaredNorm"} column. */
    public static final /* synthetic */ double $anonfun$computeSilhouetteScore$1(Vector vector) {
        return scala.math.package$.MODULE$.pow(Vectors$.MODULE$.norm(vector, 2.0d), 2.0d);
    }

    /** Body of the per-point scoring UDF; delegates to {@link #computeSilhouetteCoefficient}. */
    public static final /* synthetic */ double $anonfun$computeSilhouetteScore$3(Broadcast broadcast, Vector vector, double d, double d2, double d3) {
        return MODULE$.computeSilhouetteCoefficient(broadcast, vector, d, d2, d3);
    }

    /** Publishes the singleton through {@link #MODULE$}; Kryo registration starts out not performed. */
    private SquaredEuclideanSilhouette$() {
        MODULE$ = this;
        this.kryoRegistrationPerformed = false;
    }
}
