package org.apache.spark.ml.evaluation;

import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.util.MetadataUtils$;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DoubleType$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple4;
import scala.collection.immutable.Map;
import scala.collection.immutable.MapLike;
import scala.math.Ordering$Double$;
import scala.reflect.ClassTag$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;

/* compiled from: ClusteringMetrics.scala */
/* loaded from: input_file:org/apache/spark/ml/evaluation/CosineSilhouette$.class */
/**
 * Cosine-distance implementation of the silhouette computation.
 *
 * <p>This is the decompiled form of a Scala singleton {@code object}. The single instance is
 * exposed through {@link #MODULE$}. The approach works in four steps:
 * <ol>
 *   <li>Feature vectors are L2-normalized into a temporary column named
 *       {@code "normalizedFeatures"}.</li>
 *   <li>Each cluster is summarized by the weighted sum of its normalized vectors and the sum of
 *       its weights.</li>
 *   <li>For unit vectors, cosine distance is {@code 1 - p.x}.</li>
 *   <li>The weighted mean distance from a point {@code p} to a cluster is therefore
 *       {@code 1 - p.(featureSum) / weightSum}, a single dot product per cluster.</li>
 * </ol>
 */
public final class CosineSilhouette$ extends Silhouette {
    /** Singleton instance; assigned by the private constructor run from the static initializer. */
    public static CosineSilhouette$ MODULE$;
    /** Name of the temporary column that holds the L2-normalized feature vectors. */
    private final String normalizedFeaturesColName;

    static {
        new CosineSilhouette$();
    }

    /**
     * Aggregates, per cluster, the weighted sum of normalized feature vectors and the weight sum.
     *
     * <p>The normalized-features column must already exist in {@code dataset}.
     * {@link #computeSilhouetteScore} adds it before calling this method.
     *
     * @param dataset rows carrying the cluster-id, weight and normalized-features columns
     * @param str name of the original features column; only used to read the feature count
     *     from its metadata, which sizes the per-cluster sum vectors
     * @param str2 name of the cluster-id column (cast to double here)
     * @param str3 name of the weight column; it is NOT cast here and is read with
     *     {@code Row.getDouble}, so presumably it must already be of double type (TODO confirm
     *     with callers)
     * @return map from cluster id (boxed double) to {@code (sum of weight * normalizedFeatures,
     *     sum of weight)}
     */
    public Map<Object, Tuple2<Vector, Object>> computeClusterStats(Dataset<Row> dataset, String str, String str2, String str3) {
        int numFeatures = MetadataUtils$.MODULE$.getNumFeatures(dataset, str);
        return RDD$.MODULE$.rddToPairRDDFunctions(
                RDD$.MODULE$.rddToPairRDDFunctions(
                        dataset.select(Predef$.MODULE$.wrapRefArray(new Column[]{
                                functions$.MODULE$.col(str2).cast(DoubleType$.MODULE$),
                                functions$.MODULE$.col(this.normalizedFeaturesColName),
                                functions$.MODULE$.col(str3)})).rdd().map(row -> {
                            // Key each row by cluster id: (clusterId, (normalizedFeatures, weight)).
                            return new Tuple2(BoxesRunTime.boxToDouble(row.getDouble(0)),
                                    new Tuple2(row.getAs(1), BoxesRunTime.boxToDouble(row.getDouble(2))));
                        }, ClassTag$.MODULE$.apply(Tuple2.class)),
                        ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.apply(Tuple2.class), Ordering$Double$.MODULE$)
                .aggregateByKey(
                        // Zero value (null vector, 0.0): the dense sum vector is allocated lazily
                        // by seqOp on the first sample it sees.
                        new Tuple2((Object) null, BoxesRunTime.boxToDouble(0.0d)),
                        // seqOp: fold one (features, weight) sample into a (featureSum, weightSum) accumulator.
                        (acc, sample) -> {
                            Tuple2 stats = (Tuple2) acc;
                            Tuple2 point = (Tuple2) sample;
                            if (stats == null || point == null) {
                                // Same failure the original Scala pattern match raises.
                                throw new MatchError(new Tuple2(acc, sample));
                            }
                            DenseVector featureSum = (DenseVector) stats._1();
                            double weightSum = stats._2$mcD$sp();
                            Vector features = (Vector) point._1();
                            double weight = point._2$mcD$sp();
                            DenseVector sum = featureSum == null
                                    ? Vectors$.MODULE$.zeros(numFeatures).toDense()
                                    : featureSum;
                            // sum += weight * features, in place.
                            BLAS$.MODULE$.axpy(weight, features, sum);
                            return new Tuple2(sum, BoxesRunTime.boxToDouble(weightSum + weight));
                        },
                        // combOp: merge two partial (featureSum, weightSum) accumulators.
                        (left, right) -> {
                            Tuple2 leftStats = (Tuple2) left;
                            Tuple2 rightStats = (Tuple2) right;
                            if (leftStats == null || rightStats == null) {
                                // Same failure the original Scala pattern match raises.
                                throw new MatchError(new Tuple2(left, right));
                            }
                            DenseVector leftSum = (DenseVector) leftStats._1();
                            double leftWeight = leftStats._2$mcD$sp();
                            DenseVector rightSum = (DenseVector) rightStats._1();
                            double rightWeight = rightStats._2$mcD$sp();
                            DenseVector merged;
                            if (leftSum == null) {
                                merged = rightSum;
                            } else if (rightSum == null) {
                                merged = leftSum;
                            } else {
                                // Accumulate into the left buffer to avoid allocating a new vector.
                                BLAS$.MODULE$.axpy(1.0d, rightSum, leftSum);
                                merged = leftSum;
                            }
                            return new Tuple2(merged, BoxesRunTime.boxToDouble(leftWeight + rightWeight));
                        },
                        ClassTag$.MODULE$.apply(Tuple2.class)),
                ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.apply(Tuple2.class), Ordering$Double$.MODULE$)
                .collectAsMap().toMap(Predef$.MODULE$.$conforms());
    }

    /**
     * Computes the silhouette coefficient of a single point under cosine distance.
     *
     * <p>Delegates to {@code pointSilhouetteCoefficient}, which is inherited from
     * {@code Silhouette} and not visible here. It receives:
     * <ul>
     *   <li>the set of all cluster ids;</li>
     *   <li>the point's cluster id;</li>
     *   <li>that cluster's weight sum;</li>
     *   <li>the point's weight;</li>
     *   <li>a function returning the mean cosine distance from this point to a given cluster.</li>
     * </ul>
     *
     * @param broadcast broadcast map from cluster id to (weighted normalized-feature sum,
     *     weight sum), as produced by {@link #computeClusterStats}
     * @param vector the point's normalized feature vector
     * @param d id of the cluster the point is assigned to; looked up with Scala
     *     {@code Map.apply}, so an unknown id fails
     * @param d2 the point's weight
     * @return the value computed by {@code pointSilhouetteCoefficient}
     */
    public double computeSilhouetteCoefficient(Broadcast<Map<Object, Tuple2<Vector, Object>>> broadcast, Vector vector, double d, double d2) {
        return pointSilhouetteCoefficient(((MapLike) broadcast.value()).keySet(), d, ((Tuple2) ((scala.collection.MapLike) broadcast.value()).apply(BoxesRunTime.boxToDouble(d)))._2$mcD$sp(), d2, d3 -> {
            return compute$2(d3, broadcast, vector);
        });
    }

    /**
     * Computes the overall cosine silhouette score of a clustered dataset.
     *
     * <p>The method performs these steps:
     * <ol>
     *   <li>Adds an L2-normalized copy of the features column.</li>
     *   <li>Aggregates per-cluster statistics.</li>
     *   <li>Broadcasts those statistics.</li>
     *   <li>Scores every row via the inherited {@code overallScore}, passing the weight column.</li>
     * </ol>
     * The broadcast is destroyed afterwards, on both the success and the failure path.
     *
     * @param dataset the clustered data
     * @param str name of the cluster-id (prediction) column; cast to double
     * @param str2 name of the features column
     * @param str3 name of the weight column
     * @return the score computed by {@code overallScore}
     * @throws AssertionError (via {@code scala.Predef.assert}) if fewer than two clusters exist
     */
    public double computeSilhouetteScore(Dataset<?> dataset, String str, String str2, String str3) {
        Dataset<Row> withColumn = dataset.withColumn(this.normalizedFeaturesColName, functions$.MODULE$.udf(vector -> {
            // NOTE(review): this scales the input vector in place. A zero-norm vector gives a
            // scale factor of 1.0 / 0.0 = Infinity, so its stored entries become NaN.
            BLAS$.MODULE$.scal(1.0d / Vectors$.MODULE$.norm(vector, 2.0d), vector);
            return vector;
        }, package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.evaluation.CosineSilhouette$$typecreator1$2
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                mirror.universe();
                return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }
        }), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.evaluation.CosineSilhouette$$typecreator2$2
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                mirror.universe();
                return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }
        })).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(str2)})));
        Map<Object, Tuple2<Vector, Object>> computeClusterStats = computeClusterStats(withColumn, str2, str, str3);
        Predef$.MODULE$.assert(computeClusterStats.size() > 1, () -> {
            return "Number of clusters must be greater than one.";
        });
        Broadcast broadcast = dataset.sparkSession().sparkContext().broadcast(computeClusterStats, ClassTag$.MODULE$.apply(Map.class));
        try {
            return overallScore(withColumn, functions$.MODULE$.udf((vector2, obj, obj2) -> {
                return BoxesRunTime.boxToDouble($anonfun$computeSilhouetteScore$6(broadcast, vector2, BoxesRunTime.unboxToDouble(obj), BoxesRunTime.unboxToDouble(obj2)));
            }, package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.evaluation.CosineSilhouette$$typecreator3$1
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    mirror.universe();
                    return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                }
            }), package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().Double()).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(this.normalizedFeaturesColName), functions$.MODULE$.col(str).cast(DoubleType$.MODULE$), functions$.MODULE$.col(str3)})), functions$.MODULE$.col(str3));
        } finally {
            // Release the broadcast even if scoring throws.
            broadcast.destroy();
        }
    }

    /**
     * Returns the weighted mean cosine distance from a normalized point to one cluster.
     *
     * <p>The cluster's weighted feature sum is dotted with the point and divided by the
     * cluster's weight sum. For unit vectors this is the weighted mean of {@code p.x}. One
     * minus that value is therefore the weighted mean of the cosine distances {@code 1 - p.x}.
     *
     * @param d id of the target cluster
     * @param broadcast broadcast cluster-stats map (cluster id to (feature sum, weight sum))
     * @param vector the point's normalized feature vector
     * @return {@code 1 - dot(vector, featureSum) / weightSum}
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final double compute$2(double d, Broadcast broadcast, Vector vector) {
        Tuple2 clusterStats = (Tuple2) ((scala.collection.MapLike) broadcast.value()).apply(BoxesRunTime.boxToDouble(d));
        if (clusterStats == null) {
            throw new MatchError(clusterStats);
        }
        Vector featureSum = (Vector) clusterStats._1();
        double weightSum = clusterStats._2$mcD$sp();
        return 1 - (BLAS$.MODULE$.dot(vector, featureSum) / weightSum);
    }

    /** Compiler-generated body of the per-row scoring UDF; forwards to the singleton. */
    public static final /* synthetic */ double $anonfun$computeSilhouetteScore$6(Broadcast broadcast, Vector vector, double d, double d2) {
        return MODULE$.computeSilhouetteCoefficient(broadcast, vector, d, d2);
    }

    /** Publishes the singleton and fixes the temporary column name. */
    private CosineSilhouette$() {
        MODULE$ = this;
        this.normalizedFeaturesColName = "normalizedFeatures";
    }
}
