package org.neo4j.gds.embeddings.graphsage.algo;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.embeddings.graphsage.ActivationFunctionType;
import org.neo4j.gds.embeddings.graphsage.AggregatorType;
import org.neo4j.gds.embeddings.graphsage.LayerConfig;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainParameters.class */
/**
 * Immutable bundle of the parameters used to train a GraphSage model.
 *
 * <p>Declared as a {@code record}: the canonical constructor, the component
 * accessors, and {@code equals}/{@code hashCode}/{@code toString} are generated by
 * the compiler from the component list below.
 *
 * <p>No validation or defensive copying is performed; the lists and optionals are
 * stored exactly as passed in.
 *
 * @param concurrency             concurrency setting; its {@code value()} feeds the default
 *                                sampling ratio in {@link #batchesPerIteration(long)}
 * @param batchSize               number of nodes per batch; not guarded against zero here
 * @param maxIterations           stored as given; not read by the methods of this type
 * @param searchDepth             stored as given; not read by the methods of this type
 * @param epochs                  stored as given; not read by the methods of this type
 * @param learningRate            stored as given; not read by the methods of this type
 * @param tolerance               stored as given; not read by the methods of this type
 * @param negativeSampleWeight    stored as given; not read by the methods of this type
 * @param penaltyL2               stored as given; not read by the methods of this type
 * @param embeddingDimension      row count of every layer and column count of every layer
 *                                after the first, see {@link #layerConfigs(int)}
 * @param sampleSizes             one sample size per layer; its length determines the
 *                                number of layers built by {@link #layerConfigs(int)}
 * @param featureProperties       stored as given; not read by the methods of this type
 * @param maybeBatchSamplingRatio explicit fraction of batches to use per iteration; when
 *                                empty a default is derived in
 *                                {@link #batchesPerIteration(long)}
 * @param randomSeed              optional seed for the generator that produces the
 *                                per-layer seeds in {@link #layerConfigs(int)}
 * @param aggregatorType          aggregator type applied to every layer
 * @param activationFunction      activation function applied to every layer
 */
public record GraphSageTrainParameters(
    Concurrency concurrency,
    int batchSize,
    int maxIterations,
    int searchDepth,
    int epochs,
    double learningRate,
    double tolerance,
    int negativeSampleWeight,
    double penaltyL2,
    int embeddingDimension,
    List<Integer> sampleSizes,
    List<String> featureProperties,
    Optional<Double> maybeBatchSamplingRatio,
    Optional<Long> randomSeed,
    AggregatorType aggregatorType,
    ActivationFunctionType activationFunction
) {

    /**
     * Returns how many batches of {@code batchSize} nodes are needed to cover
     * {@code nodeCount} nodes, counting a trailing partial batch as a full one.
     *
     * @param nodeCount total number of nodes
     * @return {@code ceil(nodeCount / batchSize)}
     */
    public long numberOfBatches(long nodeCount) {
        // Divide in floating point: an integer quotient is already truncated, which
        // would turn Math.ceil into a no-op and drop the last partial batch.
        return (long) Math.ceil(nodeCount / (double) batchSize);
    }

    /**
     * Returns the number of batches to process in one iteration.
     *
     * <p>The total batch count is scaled by {@code maybeBatchSamplingRatio} when present;
     * otherwise by {@code min(1.0, batchSize * concurrency / nodeCount)}.
     *
     * @param nodeCount total number of nodes
     * @return {@code ceil(samplingRatio * numberOfBatches(nodeCount))}
     */
    public int batchesPerIteration(long nodeCount) {
        // Widen to double before multiplying and dividing so the ratio is neither
        // truncated to 0 nor subject to integer overflow in the product.
        double samplingRatio = maybeBatchSamplingRatio.orElseGet(
            () -> Math.min(1.0d, ((double) batchSize * concurrency.value()) / nodeCount)
        );
        return (int) Math.ceil(samplingRatio * numberOfBatches(nodeCount));
    }

    /**
     * Builds one {@link LayerConfig} per entry of {@code sampleSizes}.
     *
     * <p>Every layer has {@code embeddingDimension} rows. The first layer has
     * {@code inputDimension} columns; all later layers have {@code embeddingDimension}
     * columns. Each layer receives its own seed drawn from a {@link Random} that is
     * seeded with {@code randomSeed} when that is present.
     *
     * @param inputDimension column count of the first layer
     * @return the layer configurations, in the order of {@code sampleSizes}
     */
    public List<LayerConfig> layerConfigs(int inputDimension) {
        Random random = new Random();
        randomSeed.ifPresent(random::setSeed);

        int layerCount = sampleSizes.size();
        List<LayerConfig> layers = new ArrayList<>(layerCount);
        for (int layer = 0; layer < layerCount; layer++) {
            int columns = layer == 0 ? inputDimension : embeddingDimension;
            layers.add(
                LayerConfig.builder()
                    .aggregatorType(aggregatorType)
                    .activationFunction(activationFunction)
                    .rows(embeddingDimension)
                    .cols(columns)
                    .sampleSize(sampleSizes.get(layer).intValue())
                    .randomSeed(random.nextLong())
                    .build()
            );
        }
        return layers;
    }
}
