package org.neo4j.gds.ml.models.mlp;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.SplittableRandom;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.TrainingMethod;
import org.neo4j.gds.ml.models.mlp.ImmutableMLPClassifierData;

/**
 * Parameters of a multi-layer perceptron classifier: one weight matrix and one bias vector
 * per layer transition.
 *
 * <p>Instances produced by {@link #create(int, int, List, SplittableRandom)} store the
 * parameters in input-to-output order; the derived accessors below rely on that ordering
 * (first weight matrix for the feature dimension, last bias vector for the class count).
 */
@ValueClass
public interface MLPClassifierData extends Classifier.ClassifierData, Serializable {
    /** Weight matrices, one per layer transition. */
    List<Weights<Matrix>> weights();

    /** Bias vectors, one per layer transition. */
    List<Weights<Vector>> biases();

    /**
     * Returns the number of bias vectors plus one.
     * NOTE(review): presumably this counts the input layer in addition to every layer
     * that owns a bias - confirm against the consumers of this value.
     */
    @Value.Derived
    default int depth() {
        return biases().size() + 1;
    }

    /** Returns the size (dimension 0) of the last bias vector. */
    @Override
    @Value.Derived
    default int numberOfClasses() {
        List<Weights<Vector>> allBiases = biases();
        return allBiases.get(allBiases.size() - 1).dimension(0);
    }

    /** Returns dimension 1 of the first weight matrix. */
    @Override
    @Value.Derived
    default int featureDimension() {
        return weights().get(0).dimension(1);
    }

    /** Always {@link TrainingMethod#MLPClassification}. */
    @Override
    default TrainingMethod trainerMethod() {
        return TrainingMethod.MLPClassification;
    }

    /**
     * Creates randomly initialised parameters for a network that maps {@code i2} inputs
     * through the given hidden layers to {@code i} outputs.
     *
     * <p>For every layer transition one seed is drawn from {@code splittableRandom} for the
     * weights and then one for the bias, in input-to-output order, so a given generator
     * state always yields the same parameters.
     *
     * @param i                size of the final (output) layer; this is the size of the last bias vector
     * @param i2               input size of the first layer
     * @param list             sizes of the hidden layers in input-to-output order; must not be empty
     * @param splittableRandom source of the per-layer seeds
     * @return data holding {@code list.size() + 1} weight matrices and as many bias vectors
     * @throws IllegalArgumentException if {@code list} is empty
     */
    static MLPClassifierData create(int i, int i2, List<Integer> list, SplittableRandom splittableRandom) {
        if (list.isEmpty()) {
            // Fail with a descriptive message instead of an IndexOutOfBoundsException from list.get(0).
            throw new IllegalArgumentException("At least one hidden layer size is required.");
        }
        final int classCount = i;
        final int featureCount = i2;

        int layerCount = list.size() + 1;
        List<Weights<Matrix>> layerWeights = new ArrayList<>(layerCount);
        List<Weights<Vector>> layerBiases = new ArrayList<>(layerCount);

        // Each layer consumes the output of the previous one; the first consumes the features.
        int inputSize = featureCount;
        for (int hiddenSize : list) {
            layerWeights.add(generateWeights(hiddenSize, inputSize, splittableRandom.nextLong()));
            layerBiases.add(generateBias(hiddenSize, splittableRandom.nextLong()));
            inputSize = hiddenSize;
        }

        // Output layer: last hidden layer -> classes.
        layerWeights.add(generateWeights(classCount, inputSize, splittableRandom.nextLong()));
        layerBiases.add(generateBias(classCount, splittableRandom.nextLong()));

        return ImmutableMLPClassifierData.builder().weights(layerWeights).biases(layerBiases).build();
    }

    /**
     * Builds a weight matrix whose {@code rows * cols} entries are drawn uniformly from
     * {@code [-b, b)} with {@code b = sqrt(2 / cols)}, using a {@link Random} seeded with {@code seed}.
     *
     * @param rows first size argument handed to {@link Matrix} (the layer's output size at the call sites)
     * @param cols second size argument handed to {@link Matrix} (the layer's input size at the call sites)
     * @param seed seed for the random generator
     * @throws ArithmeticException if {@code rows * cols} overflows an {@code int}
     */
    private static Weights<Matrix> generateWeights(int rows, int cols, long seed) {
        double bound = Math.sqrt(2.0d / cols);
        double[] values = new Random(seed).doubles(Math.multiplyExact(rows, cols), -bound, bound).toArray();
        return new Weights<>(new Matrix(values, rows, cols));
    }

    /**
     * Builds a bias vector of {@code size} entries drawn uniformly from {@code [-b, b)} with
     * {@code b = sqrt(2 / size)}, using a {@link Random} seeded with {@code seed}.
     */
    private static Weights<Vector> generateBias(int size, long seed) {
        double bound = Math.sqrt(2.0d / size);
        double[] values = new Random(seed).doubles(size, -bound, bound).toArray();
        return new Weights<>(new Vector(values));
    }

    /** Returns a fresh builder of the generated immutable implementation. */
    static ImmutableMLPClassifierData.Builder builder() {
        return ImmutableMLPClassifierData.builder();
    }
}
