package org.neo4j.gds.ml.models.automl;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.stream.Collectors;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.automl.hyperparameter.DoubleRangeParameter;
import org.neo4j.gds.ml.models.automl.hyperparameter.IntegerRangeParameter;

/* loaded from: input_file:org/neo4j/gds/ml/models/automl/RandomSearch.class */
/**
 * Hyper-parameter optimizer that first yields every concrete (fully specified) trainer config once,
 * and then yields up to {@code maxTrials} configs sampled at random from the tunable (ranged) configs.
 *
 * <p>For each random trial, one tunable config is chosen uniformly. Every double range and every integer
 * range of that config is then sampled independently, and the config is materialized with the sampled values.
 */
public class RandomSearch implements HyperParameterOptimizer {
    /** Lower clamp applied to a log-scale range's minimum, since {@code Math.log(0)} is -Infinity. */
    private static final double MIN_LOG_SCALE_VALUE = 1.0E-20d;

    private final List<TunableTrainerConfig> concreteConfigs;
    private final List<TunableTrainerConfig> tunableConfigs;
    private final int totalNumberOfTrials;
    private final int numberOfConcreteTrials;
    private final SplittableRandom random;
    private int numberOfFinishedTrials;

    /**
     * Creates a seeded random search.
     *
     * @param configsByMethod candidate configs grouped by training method
     * @param maxTrials       number of random trials to run in addition to the concrete configs
     * @param randomSeed      seed for the random generator, making sampling reproducible
     */
    public RandomSearch(Map<TrainingMethod, List<TunableTrainerConfig>> configsByMethod, int maxTrials, long randomSeed) {
        this(configsByMethod, maxTrials, Optional.of(randomSeed));
    }

    /**
     * Creates a random search.
     *
     * @param configsByMethod candidate configs grouped by training method
     * @param maxTrials       number of random trials to run in addition to the concrete configs
     * @param randomSeed      optional seed; when empty an unseeded {@link SplittableRandom} is used
     */
    public RandomSearch(Map<TrainingMethod, List<TunableTrainerConfig>> configsByMethod, int maxTrials, Optional<Long> randomSeed) {
        // Single pass; partitioningBy always contains both keys and keeps encounter order.
        Map<Boolean, List<TunableTrainerConfig>> byConcreteness = configsByMethod.values().stream()
            .flatMap(List::stream)
            .collect(Collectors.partitioningBy(TunableTrainerConfig::isConcrete));
        this.concreteConfigs = byConcreteness.get(true);
        this.tunableConfigs = byConcreteness.get(false);
        this.numberOfConcreteTrials = this.concreteConfigs.size();
        this.totalNumberOfTrials = maxTrials + this.numberOfConcreteTrials;
        this.random = randomSeed
            .map(seed -> new SplittableRandom(seed))
            .orElseGet(SplittableRandom::new);
        this.numberOfFinishedTrials = 0;
    }

    /**
     * Returns whether another trial is available.
     *
     * <p>Concrete configs are always available until each has been yielded once. Random trials are available
     * only while the trial budget is not exhausted and there is at least one tunable config to sample from.
     */
    @Override
    public boolean hasNext() {
        if (numberOfFinishedTrials < numberOfConcreteTrials) {
            return true;
        }
        return numberOfFinishedTrials < totalNumberOfTrials && !tunableConfigs.isEmpty();
    }

    /**
     * Returns the next trainer config: concrete configs first (in order), then randomly sampled ones.
     *
     * @throws IllegalStateException if {@link #hasNext()} is false
     */
    @Override
    public TrainerConfig next() {
        if (!hasNext()) {
            throw new IllegalStateException("RandomSearch has already exhausted the maximum trials or the parameter space.");
        }
        if (numberOfFinishedTrials < concreteConfigs.size()) {
            // Concrete configs have no ranges to sample, so they are materialized with no overrides.
            return concreteConfigs.get(numberOfFinishedTrials++).materialize(Map.of());
        }
        numberOfFinishedTrials++;
        TunableTrainerConfig chosen = tunableConfigs.get(random.nextInt(tunableConfigs.size()));
        return sample(chosen);
    }

    /** Draws a value for every ranged hyper-parameter of {@code config} and materializes it. */
    private TrainerConfig sample(TunableTrainerConfig config) {
        Map<String, Object> sampledValues = new HashMap<>();
        // Doubles are drawn before integers; this order fixes the RNG sequence for a given seed.
        config.doubleRanges.forEach((name, range) -> sampledValues.put(name, sampleDouble(range)));
        config.integerRanges.forEach((name, range) -> sampledValues.put(name, sampleInteger(range)));
        return config.materialize(sampledValues);
    }

    /**
     * Samples an integer uniformly from {@code [min, max)}.
     *
     * <p>A degenerate range ({@code min == max}) yields {@code min}; {@link SplittableRandom#nextInt(int, int)}
     * would otherwise throw. NOTE(review): the upper bound is exclusive, so {@code max} itself is never
     * drawn — confirm this matches the user-facing range semantics.
     */
    private int sampleInteger(IntegerRangeParameter range) {
        int min = range.min().intValue();
        int max = range.max().intValue();
        if (min == max) {
            return min;
        }
        return random.nextInt(min, max);
    }

    /**
     * Samples a double from {@code [min, max)}: uniformly, or log-uniformly when the range is log-scaled.
     *
     * <p>For log scale, {@code min} is clamped to {@link #MIN_LOG_SCALE_VALUE} before taking the logarithm.
     * A degenerate range ({@code min == max}) yields {@code min}; {@link SplittableRandom#nextDouble(double, double)}
     * would otherwise throw.
     */
    private double sampleDouble(DoubleRangeParameter range) {
        double min = range.min().doubleValue();
        double max = range.max().doubleValue();
        if (min == max) {
            return min;
        }
        if (range.logScale()) {
            double logMin = Math.log(Math.max(min, MIN_LOG_SCALE_VALUE));
            double logMax = Math.log(max);
            return Math.exp(random.nextDouble(logMin, logMax));
        }
        return random.nextDouble(min, max);
    }
}
