package com.github.jelmerk.knn.spark;

import com.github.jelmerk.knn.scalalike.Index;
import com.github.jelmerk.knn.spark.KnnAlgorithmParams;
import com.github.jelmerk.knn.spark.KnnModelParams;
import org.apache.spark.Partitioner;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.ArrayType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import scala.Function2;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq$;
import scala.collection.immutable.List$;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;

/* compiled from: KnnAlgorithm.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005=e!B\u0001\u0003\u0003\u0003i!\u0001D&o]\u0006cwm\u001c:ji\"l'BA\u0002\u0005\u0003\u0015\u0019\b/\u0019:l\u0015\t)a!A\u0002l]:T!a\u0002\u0005\u0002\u000f),G.\\3sW*\u0011\u0011BC\u0001\u0007O&$\b.\u001e2\u000b\u0003-\t1aY8n\u0007\u0001)\"A\u0004\u000f\u0014\u0007\u0001y\u0001\u0006E\u0002\u00111ii\u0011!\u0005\u0006\u0003%M\t!!\u001c7\u000b\u0005\r!\"BA\u000b\u0017\u0003\u0019\t\u0007/Y2iK*\tq#A\u0002pe\u001eL!!G\t\u0003\u0013\u0015\u001bH/[7bi>\u0014\bCA\u000e\u001d\u0019\u0001!Q!\b\u0001C\u0002y\u0011a\u0001V'pI\u0016d\u0017CA\u0010&!\t\u00013%D\u0001\"\u0015\u0005\u0011\u0013!B:dC2\f\u0017B\u0001\u0013\"\u0005\u001dqu\u000e\u001e5j]\u001e\u00042\u0001\u0005\u0014\u001b\u0013\t9\u0013CA\u0003N_\u0012,G\u000e\u0005\u0002*U5\t!!\u0003\u0002,\u0005\t\u00112J\u001c8BY\u001e|'/\u001b;i[B\u000b'/Y7t\u0011!i\u0003A!b\u0001\n\u0003r\u0013aA;jIV\tq\u0006\u0005\u00021g9\u0011\u0001%M\u0005\u0003e\u0005\na\u0001\u0015:fI\u00164\u0017B\u0001\u001b6\u0005\u0019\u0019FO]5oO*\u0011!'\t\u0005\to\u0001\u0011\t\u0011)A\u0005_\u0005!Q/\u001b3!\u0011\u0015I\u0004\u0001\"\u0001;\u0003\u0019a\u0014N\\5u}Q\u00111\b\u0010\t\u0004S\u0001Q\u0002\"B\u00179\u0001\u0004y\u0003\"\u0002 
\u0001\t\u0003y\u0014AD:fi&#WM\u001c;jif\u001cu\u000e\u001c\u000b\u0003\u0001\u0006k\u0011\u0001\u0001\u0005\u0006\u0005v\u0002\raL\u0001\u0006m\u0006dW/\u001a\u0005\u0006\t\u0002!\t!R\u0001\rg\u0016$h+Z2u_J\u001cu\u000e\u001c\u000b\u0003\u0001\u001aCQAQ\"A\u0002=BQ\u0001\u0013\u0001\u0005\u0002%\u000bqb]3u\u001d\u0016Lw\r\u001b2peN\u001cu\u000e\u001c\u000b\u0003\u0001*CQAQ$A\u0002=BQ\u0001\u0014\u0001\u0005\u00025\u000bAa]3u\u0017R\u0011\u0001I\u0014\u0005\u0006\u0005.\u0003\ra\u0014\t\u0003AAK!!U\u0011\u0003\u0007%sG\u000fC\u0003T\u0001\u0011\u0005A+\u0001\ttKRtU/\u001c)beRLG/[8ogR\u0011\u0001)\u0016\u0005\u0006\u0005J\u0003\ra\u0014\u0005\u0006/\u0002!\t\u0001W\u0001\u0014g\u0016$H)[:uC:\u001cWMR;oGRLwN\u001c\u000b\u0003\u0001fCQA\u0011,A\u0002=BQa\u0017\u0001\u0005Bq\u000b1AZ5u)\tQR\fC\u0003_5\u0002\u0007q,A\u0004eCR\f7/\u001a;1\u0005\u0001<\u0007cA1eM6\t!M\u0003\u0002d'\u0005\u00191/\u001d7\n\u0005\u0015\u0014'a\u0002#bi\u0006\u001cX\r\u001e\t\u00037\u001d$\u0011\u0002[/\u0002\u0002\u0003\u0005)\u0011A5\u0003\u0007}##'\u0005\u0002 
UB\u0011\u0001e[\u0005\u0003Y\u0006\u00121!\u00118z\u0011\u0015q\u0007\u0001\"\u0011p\u0003=!(/\u00198tM>\u0014XnU2iK6\fGC\u00019w!\t\tH/D\u0001s\u0015\t\u0019(-A\u0003usB,7/\u0003\u0002ve\nQ1\u000b\u001e:vGR$\u0016\u0010]3\t\u000b]l\u0007\u0019\u00019\u0002\rM\u001c\u0007.Z7b\u0011\u0015I\b\u0001\"\u0011{\u0003\u0011\u0019w\u000e]=\u0015\u0005=Y\b\"\u0002?y\u0001\u0004i\u0018!B3yiJ\f\u0007c\u0001@\u0002\u00045\tqPC\u0002\u0002\u0002E\tQ\u0001]1sC6L1!!\u0002��\u0005!\u0001\u0016M]1n\u001b\u0006\u0004\bbBA\u0005\u0001\u0019E\u00111B\u0001\fGJ,\u0017\r^3J]\u0012,\u0007\u0010\u0006\u0003\u0002\u000e\u0005-\u0002cCA\b\u0003+y\u0013\u0011DA\u0013\u0003?i!!!\u0005\u000b\u0007\u0005MA!A\u0005tG\u0006d\u0017\r\\5lK&!\u0011qCA\t\u0005\u0015Ie\u000eZ3y!\u0015\u0001\u00131DA\u0010\u0013\r\ti\"\t\u0002\u0006\u0003J\u0014\u0018-\u001f\t\u0004A\u0005\u0005\u0012bAA\u0012C\t)a\t\\8biB\u0019\u0011&a\n\n\u0007\u0005%\"AA\u0005J]\u0012,\u00070\u0013;f[\"9\u0011QFA\u0004\u0001\u0004y\u0015\u0001D7bq&#X-\\\"pk:$\bbBA\u0019\u0001\u0019E\u00111G\u0001\fGJ,\u0017\r^3N_\u0012,G\u000eF\u0005\u001b\u0003k\t9$a\u000f\u0002H!1Q&a\fA\u0002=Bq!!\u000f\u00020\u0001\u0007q*A\u0007ok6\u0004\u0016M\u001d;ji&|gn\u001d\u0005\t\u0003{\ty\u00031\u0001\u0002@\u0005Y\u0001/\u0019:uSRLwN\\3s!\u0011\t\t%a\u0011\u000e\u0003MI1!!\u0012\u0014\u0005-\u0001\u0016M\u001d;ji&|g.\u001a:\t\u0011\u0005%\u0013q\u0006a\u0001\u0003\u0017\nq!\u001b8eS\u000e,7\u000f\u0005\u0004\u0002N\u0005M\u0013qK\u0007\u0003\u0003\u001fR1!!\u0015\u0014\u0003\r\u0011H\rZ\u0005\u0005\u0003+\nyEA\u0002S\t\u0012\u0003b\u0001IA-\u001f\u00065\u0011bAA.C\t1A+\u001e9mKJBq!a\u0018\u0001\t#\t\t'\u0001\feSN$\u0018M\\2f\rVt7\r^5p]\nKh*Y7f)\u0011\t\u0019'a#\u0011\u0011\u0005\u0015\u0014QQA\r\u0003?qA!a\u001a\u0002\u0002:!\u0011\u0011NA@\u001d\u0011\tY'! 
\u000f\t\u00055\u00141\u0010\b\u0005\u0003_\nIH\u0004\u0003\u0002r\u0005]TBAA:\u0015\r\t)\bD\u0001\u0007yI|w\u000e\u001e \n\u0003-I!!\u0003\u0006\n\u0005\u001dA\u0011BA\u0003\u0007\u0013\r\t\u0019\u0002B\u0005\u0005\u0003\u0007\u000b\t\"A\u0004qC\u000e\\\u0017mZ3\n\t\u0005\u001d\u0015\u0011\u0012\u0002\u0011\t&\u001cH/\u00198dK\u001a+hn\u0019;j_:TA!a!\u0002\u0012!9\u0011QRA/\u0001\u0004y\u0013\u0001\u00028b[\u0016\u0004")
/* loaded from: input_file:com/github/jelmerk/knn/spark/KnnAlgorithm.class */
public abstract class KnnAlgorithm<TModel extends Model<TModel>> extends Estimator<TModel> implements KnnAlgorithmParams {
    /** Unique identifier of this estimator; assigned once in the constructor. */
    private final String uid;

    // The param fields below are written by the trait setter methods
    // (com$github$jelmerk$knn$spark$...$_setter_$..._$eq) rather than directly in the
    // constructor, so they cannot be declared final in Java source.

    /** Param holding the number of partitions. */
    private IntParam numPartitions;
    /** Param holding the name of the identifier column. */
    private Param<String> identifierCol;
    /** Param holding the name of the vector column. */
    private Param<String> vectorCol;
    /** Param holding the name of the neighbors column. */
    private Param<String> neighborsCol;
    /** Param holding the name of the distance function. */
    private Param<String> distanceFunction;
    /** Param named {@code k}; its meaning is defined by KnnModelParams, not visible here. */
    private IntParam k;

    /**
     * Accessor for the {@code numPartitions} param.
     *
     * @return the param object stored on this estimator
     */
    @Override // com.github.jelmerk.knn.spark.KnnAlgorithmParams
    public IntParam numPartitions() {
        return numPartitions;
    }

    /**
     * Trait setter that stores the {@code numPartitions} param on this instance.
     *
     * @param intParam the param object to store
     */
    @Override // com.github.jelmerk.knn.spark.KnnAlgorithmParams
    public void com$github$jelmerk$knn$spark$KnnAlgorithmParams$_setter_$numPartitions_$eq(IntParam intParam) {
        numPartitions = intParam;
    }

    /**
     * Delegates to the KnnAlgorithmParams trait implementation.
     *
     * @return the value produced by the trait's {@code getNumPartitions}
     *         (presumably the current value of the numPartitions param)
     */
    @Override // com.github.jelmerk.knn.spark.KnnAlgorithmParams
    public int getNumPartitions() {
        int partitions = KnnAlgorithmParams.Cclass.getNumPartitions(this);
        return partitions;
    }

    /**
     * Accessor for the {@code identifierCol} param.
     *
     * @return the param object stored on this estimator
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public Param<String> identifierCol() {
        return identifierCol;
    }

    /**
     * Accessor for the {@code vectorCol} param.
     *
     * @return the param object stored on this estimator
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public Param<String> vectorCol() {
        return vectorCol;
    }

    /**
     * Accessor for the {@code neighborsCol} param.
     *
     * @return the param object stored on this estimator
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public Param<String> neighborsCol() {
        return neighborsCol;
    }

    /**
     * Accessor for the {@code distanceFunction} param.
     *
     * @return the param object stored on this estimator
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public Param<String> distanceFunction() {
        return distanceFunction;
    }

    /**
     * Accessor for the {@code k} param.
     *
     * @return the param object stored on this estimator
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public IntParam k() {
        return k;
    }

    /**
     * Trait setter that stores the {@code identifierCol} param on this instance.
     *
     * @param param the param object to store (raw type, as declared by the trait)
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public void com$github$jelmerk$knn$spark$KnnModelParams$_setter_$identifierCol_$eq(Param param) {
        identifierCol = param;
    }

    /**
     * Trait setter that stores the {@code vectorCol} param on this instance.
     *
     * @param param the param object to store (raw type, as declared by the trait)
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public void com$github$jelmerk$knn$spark$KnnModelParams$_setter_$vectorCol_$eq(Param param) {
        vectorCol = param;
    }

    /**
     * Trait setter that stores the {@code neighborsCol} param on this instance.
     *
     * @param param the param object to store (raw type, as declared by the trait)
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public void com$github$jelmerk$knn$spark$KnnModelParams$_setter_$neighborsCol_$eq(Param param) {
        neighborsCol = param;
    }

    /**
     * Trait setter that stores the {@code distanceFunction} param on this instance.
     *
     * @param param the param object to store (raw type, as declared by the trait)
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public void com$github$jelmerk$knn$spark$KnnModelParams$_setter_$distanceFunction_$eq(Param param) {
        distanceFunction = param;
    }

    /**
     * Trait setter that stores the {@code k} param on this instance.
     *
     * @param intParam the param object to store
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public void com$github$jelmerk$knn$spark$KnnModelParams$_setter_$k_$eq(IntParam intParam) {
        k = intParam;
    }

    /**
     * Delegates to the KnnModelParams trait implementation.
     *
     * @return the value produced by the trait's {@code getIdentifierCol}
     *         (presumably the configured identifier column name)
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public String getIdentifierCol() {
        String columnName = KnnModelParams.Cclass.getIdentifierCol(this);
        return columnName;
    }

    /**
     * Delegates to the KnnModelParams trait implementation.
     *
     * @return the value produced by the trait's {@code getVectorCol}
     *         (presumably the configured vector column name)
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public String getVectorCol() {
        String columnName = KnnModelParams.Cclass.getVectorCol(this);
        return columnName;
    }

    /**
     * Delegates to the KnnModelParams trait implementation.
     *
     * @return the value produced by the trait's {@code getNeighborsCol}
     *         (presumably the configured neighbors column name)
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public String getNeighborsCol() {
        String columnName = KnnModelParams.Cclass.getNeighborsCol(this);
        return columnName;
    }

    /**
     * Delegates to the KnnModelParams trait implementation.
     *
     * @return the value produced by the trait's {@code getDistanceFunction}
     *         (presumably the configured distance function name)
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public String getDistanceFunction() {
        String functionName = KnnModelParams.Cclass.getDistanceFunction(this);
        return functionName;
    }

    /**
     * Delegates to the KnnModelParams trait implementation.
     *
     * @return the value produced by the trait's {@code getK}
     *         (presumably the current value of the k param)
     */
    @Override // com.github.jelmerk.knn.spark.KnnModelParams
    public int getK() {
        int value = KnnModelParams.Cclass.getK(this);
        return value;
    }

    /**
     * Returns the identifier this estimator was constructed with.
     *
     * @return the uid passed to the constructor
     */
    public String uid() {
        return uid;
    }

    /**
     * Sets the {@code identifierCol} param.
     *
     * @param str the value to assign to the identifierCol param
     * @return the result of {@code set(...)}, cast to this estimator type for call chaining
     */
    public KnnAlgorithm<TModel> setIdentityCol(String str) {
        Object updated = set(identifierCol(), str);
        return (KnnAlgorithm) updated;
    }

    /**
     * Sets the {@code vectorCol} param.
     *
     * @param str the value to assign to the vectorCol param
     * @return the result of {@code set(...)}, cast to this estimator type for call chaining
     */
    public KnnAlgorithm<TModel> setVectorCol(String str) {
        Object updated = set(vectorCol(), str);
        return (KnnAlgorithm) updated;
    }

    /**
     * Sets the {@code neighborsCol} param.
     *
     * @param str the value to assign to the neighborsCol param
     * @return the result of {@code set(...)}, cast to this estimator type for call chaining
     */
    public KnnAlgorithm<TModel> setNeighborsCol(String str) {
        Object updated = set(neighborsCol(), str);
        return (KnnAlgorithm) updated;
    }

    /**
     * Sets the {@code k} param.
     *
     * @param i the value to assign to the k param (boxed before being stored)
     * @return the result of {@code set(...)}, cast to this estimator type for call chaining
     */
    public KnnAlgorithm<TModel> setK(int i) {
        Object boxed = BoxesRunTime.boxToInteger(i);
        Object updated = set(k(), boxed);
        return (KnnAlgorithm) updated;
    }

    /**
     * Sets the {@code numPartitions} param.
     *
     * @param i the value to assign to the numPartitions param (boxed before being stored)
     * @return the result of {@code set(...)}, cast to this estimator type for call chaining
     */
    public KnnAlgorithm<TModel> setNumPartitions(int i) {
        Object boxed = BoxesRunTime.boxToInteger(i);
        Object updated = set(numPartitions(), boxed);
        return (KnnAlgorithm) updated;
    }

    /**
     * Sets the {@code distanceFunction} param.
     *
     * @param str the value to assign to the distanceFunction param
     * @return the result of {@code set(...)}, cast to this estimator type for call chaining
     */
    public KnnAlgorithm<TModel> setDistanceFunction(String str) {
        Object updated = set(distanceFunction(), str);
        return (KnnAlgorithm) updated;
    }

    /**
     * Fits a model on the given dataset.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Look up the data type of the vector column in the dataset schema and derive a
     *       column expression for it via {@link #toFloatArrayColumn(DataType)}.</li>
     *   <li>Select the identifier column (cast to string, aliased {@code "id"}) and the vector
     *       expression (aliased {@code "vector"}) and decode the rows as {@code IndexItem}.</li>
     *   <li>Map the items to {@code (Int, IndexItem)} pairs with {@code KnnAlgorithm$$anonfun$7}
     *       (not visible here; presumably it assigns each item a partition id).</li>
     *   <li>Partition the pair RDD with a {@code PartitionIdPassthrough} sized by
     *       {@code getNumPartitions()} and run {@code KnnAlgorithm$$anonfun$8} per partition
     *       (not visible here; presumably it builds one index per partition).</li>
     *   <li>Hand the resulting RDD to {@code createModel} and copy this estimator's param
     *       values onto the model with {@code copyValues}.</li>
     * </ol>
     *
     * <p>The pipeline used to be duplicated once per vector-type branch; it is now written once
     * and only the column selection differs.
     *
     * @param dataset the dataset to fit on; must contain the configured vector and identifier
     *                columns
     * @return the model returned by {@code createModel}, after {@code copyValues}
     */
    public TModel fit(Dataset<?> dataset) {
        DataType vectorType = dataset.schema().apply(getVectorCol()).dataType();
        Column vectorColumn = toFloatArrayColumn(vectorType);
        PartitionIdPassthrough partitioner = new PartitionIdPassthrough(getNumPartitions());
        return copyValues(
            createModel(
                uid(),
                getNumPartitions(),
                partitioner,
                RDD$.MODULE$.rddToPairRDDFunctions(
                    dataset
                        .select(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(getIdentifierCol()).cast(StringType$.MODULE$).as("id"), vectorColumn.as("vector")}))
                        .as(dataset.sparkSession().implicits().newProductEncoder(((TypeTags) package$.MODULE$.universe()).TypeTag().apply((Mirror) package$.MODULE$.universe().runtimeMirror(KnnAlgorithm.class.getClassLoader()), new TypeCreator(this) { // from class: com.github.jelmerk.knn.spark.KnnAlgorithm$$typecreator20$1
                            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                                mirror.universe();
                                return mirror.staticClass("com.github.jelmerk.knn.spark.IndexItem").asType().toTypeConstructor();
                            }
                        })))
                        .map(new KnnAlgorithm$$anonfun$7(this), dataset.sparkSession().implicits().newProductEncoder(((TypeTags) package$.MODULE$.universe()).TypeTag().apply((Mirror) package$.MODULE$.universe().runtimeMirror(KnnAlgorithm.class.getClassLoader()), new TypeCreator(this) { // from class: com.github.jelmerk.knn.spark.KnnAlgorithm$$typecreator21$1
                            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                                Universe universe = mirror.universe();
                                return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().ThisType(mirror.staticPackage("scala").asModule().moduleClass()), mirror.staticClass("scala.Tuple2"), List$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Types.TypeApi[]{mirror.staticClass("scala.Int").asType().toTypeConstructor(), mirror.staticClass("com.github.jelmerk.knn.spark.IndexItem").asType().toTypeConstructor()})));
                            }
                        })))
                        .rdd(),
                    ClassTag$.MODULE$.Int(),
                    ClassTag$.MODULE$.apply(IndexItem.class),
                    Ordering$Int$.MODULE$)
                    .partitionBy(partitioner)
                    .mapPartitions(new KnnAlgorithm$$anonfun$8(this), true, ClassTag$.MODULE$.apply(Tuple2.class))),
            copyValues$default$2());
    }

    /**
     * Builds the column expression used to read the vector column for the given data type.
     *
     * <ul>
     *   <li>type name {@code "vector"}: wrapped in {@code Udfs.vectorToFloatArray};</li>
     *   <li>{@code ArrayType} whose element type is {@code DoubleType}: wrapped in
     *       {@code Udfs.doubleArrayToFloatArray};</li>
     *   <li>anything else: the plain column reference, unchanged.</li>
     * </ul>
     *
     * @param dataType data type of the vector column as found in the schema; may be null
     * @return the column expression to select as the vector
     */
    private Column toFloatArrayColumn(DataType dataType) {
        Column source = functions$.MODULE$.col(getVectorCol());
        // "vector".equals(...) is null-safe, matching Scala's == on a possibly-null type name.
        if (dataType != null && "vector".equals(dataType.typeName())) {
            return Udfs$.MODULE$.vectorToFloatArray().apply(Predef$.MODULE$.wrapRefArray(new Column[]{source}));
        }
        if (dataType instanceof ArrayType && DoubleType$.MODULE$.equals(((ArrayType) dataType).elementType())) {
            return Udfs$.MODULE$.doubleArrayToFloatArray().apply(Predef$.MODULE$.wrapRefArray(new Column[]{source}));
        }
        return source;
    }

    /**
     * Derives the output schema from the input schema.
     *
     * <p>The result has two fields: the identifier column (same data type as in the input
     * schema) and the neighbors column, an array of structs with a {@code "neighbor"} field of
     * the identifier's data type and a {@code "distance"} field of float type. All fields use
     * the default nullability and metadata of {@code StructField}.
     *
     * @param structType the input schema; must contain the identifier column
     * @return the output schema described above
     */
    public StructType transformSchema(StructType structType) {
        DataType identifierType = structType.apply(getIdentifierCol()).dataType();

        // Shape of one entry in the neighbors array.
        StructField neighborField = new StructField("neighbor", identifierType, StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4());
        StructField distanceField = new StructField("distance", FloatType$.MODULE$, StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4());
        StructType neighborStruct = StructType$.MODULE$.apply(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new StructField[]{neighborField, distanceField})));

        // Top-level output columns.
        StructField identifierField = new StructField(getIdentifierCol(), identifierType, StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4());
        StructField neighborsField = new StructField(getNeighborsCol(), ArrayType$.MODULE$.apply(neighborStruct), StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4());

        return StructType$.MODULE$.apply(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new StructField[]{identifierField, neighborsField})));
    }

    /**
     * Copies this estimator by delegating to {@code defaultCopy}.
     *
     * @param paramMap extra params passed through to {@code defaultCopy}
     * @return the copy produced by {@code defaultCopy}
     */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public Estimator<TModel> m4copy(ParamMap paramMap) {
        Estimator<TModel> copied = defaultCopy(paramMap);
        return copied;
    }

    /**
     * Creates an index keyed by string ids over float-array vectors of {@code IndexItem}s.
     * Implemented by subclasses.
     *
     * @param i named {@code maxItemCount} in the Scala signature; presumably the maximum number
     *          of items the index must hold (confirm against implementations)
     * @return a new index instance
     */
    public abstract Index<String, float[], IndexItem, Object> createIndex(int i);

    /**
     * Creates the model returned by {@code fit}. Implemented by subclasses.
     *
     * @param str         uid for the model; {@code fit} passes this estimator's {@code uid()}
     * @param i           number of partitions; {@code fit} passes {@code getNumPartitions()}
     * @param partitioner partitioner that {@code fit} used to partition the data
     * @param rdd         RDD of {@code (Int, Index)} pairs produced by {@code fit}
     *                    (presumably partition id paired with that partition's index)
     * @return the model instance
     */
    public abstract TModel createModel(String str, int i, Partitioner partitioner, RDD<Tuple2<Object, Index<String, float[], IndexItem, Object>>> rdd);

    /**
     * Resolves a distance function from its name.
     *
     * <p>Supported names are {@code "cosine"} and {@code "inner-product"}. The error message
     * reports the name that was actually passed in, not the current value of the
     * {@code distanceFunction} param, so a failed lookup names the offending input.
     *
     * @param str the distance function name to resolve
     * @return the matching distance function over float arrays
     * @throws IllegalArgumentException if {@code str} is not one of the supported names
     */
    public Function2<float[], float[], Object> distanceFunctionByName(String str) {
        if ("cosine".equals(str)) {
            return com.github.jelmerk.knn.scalalike.package$.MODULE$.floatCosineDistance();
        }
        if ("inner-product".equals(str)) {
            return com.github.jelmerk.knn.scalalike.package$.MODULE$.floatInnerProduct();
        }
        throw new IllegalArgumentException(str + " is not a valid distance function.");
    }

    public KnnAlgorithm(String str) {
        this.uid = str;
        KnnModelParams.Cclass.$init$(this);
        com$github$jelmerk$knn$spark$KnnAlgorithmParams$_setter_$numPartitions_$eq(new IntParam(this, "numPartitions", "number of partitions", ParamValidators$.MODULE$.gt(0.0d)));
    }
}
