package de.jungblut.classification.nn;

import de.jungblut.classification.ClassifierFactory;
import de.jungblut.classification.eval.WeightMapper;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.minimize.DenseMatrixFolder;

/**
 * Maps a flat parameter vector onto the weight matrices of a {@link MultilayerPerceptron}.
 *
 * <p>A single network is created from the supplied factory at construction time. Every call to
 * {@link #mapWeights(DoubleVector)} writes new weights into that same network and returns it, so
 * the returned object is shared between calls. Each call overwrites the weights written by the
 * previous one. NOTE(review): because that instance is mutated without synchronization, concurrent
 * use of one mapper looks unsafe — confirm before sharing across threads.
 */
public class MLPWeightMapper implements WeightMapper<MultilayerPerceptron> {

    /** Shapes used to split a flat parameter vector back into per-layer matrices. */
    private final int[][] unfoldParameters;

    /** The one network instance that receives every mapped set of weights. */
    private final MultilayerPerceptron classifier;

    /**
     * Creates a mapper around a freshly built network.
     *
     * @param classifierFactory factory used exactly once to create the network whose layer layout
     *     determines how parameter vectors are unfolded
     */
    public MLPWeightMapper(ClassifierFactory<MultilayerPerceptron> classifierFactory) {
        MultilayerPerceptron network = classifierFactory.newInstance();
        this.classifier = network;
        // The unfold shapes are derived from the network's own layer configuration.
        this.unfoldParameters =
            MultilayerPerceptronCostFunction.computeUnfoldParameters(network.getLayers());
    }

    /**
     * Unfolds the given vector into per-layer matrices and installs them as the network's weights.
     *
     * @param doubleVector flat vector holding all weights, folded in the order expected by
     *     {@code DenseMatrixFolder.unfoldMatrices} for this mapper's unfold parameters
     * @return the shared network instance, now carrying the weights from {@code doubleVector}
     */
    @Override
    public MultilayerPerceptron mapWeights(DoubleVector doubleVector) {
        final DoubleMatrix[] layerMatrices =
            DenseMatrixFolder.unfoldMatrices(doubleVector, unfoldParameters);
        int layer = 0;
        while (layer < layerMatrices.length) {
            classifier.getWeights()[layer].setWeights(layerMatrices[layer]);
            layer++;
        }
        return classifier;
    }
}
