package org.deeplearning4j.scaleout.perform;

import java.util.Collection;
import org.canova.api.conf.Configuration;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.factory.LayerFactories;
import org.deeplearning4j.scaleout.job.Job;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * {@link WorkerPerformer} that trains a single {@link Layer} on the work carried by a {@link Job}.
 *
 * <p>The layer is built in {@link #setup(Configuration)} from a JSON-serialized
 * {@link NeuralNetConfiguration}. Each {@link #perform(Job)} call fits the layer on the job's work
 * and stores the layer's parameters as the job's result; {@link #update(Object...)} overwrites the
 * layer's parameters.
 */
public class NeuralNetWorkPerformer implements WorkerPerformer {
    /** Configuration key whose value is the JSON-serialized {@link NeuralNetConfiguration}. */
    private static final String NEURAL_NET_CONF_KEY = "org.deeplearning4j.scaleout.neuralnetconf";

    /** Layer built by {@link #setup(Configuration)}; {@code null} until setup has run. */
    protected Layer neuralNetwork;

    /**
     * Fits the layer on the job's work and publishes the layer's parameters as the job's result.
     *
     * <p>A {@link DataSet} is fitted on its feature matrix; an {@link INDArray} is fitted directly.
     * Work of any other type (including {@code null}) is not fitted, but the current parameters
     * are still set as the result.
     *
     * @param job the job whose work is consumed and whose result is set
     * @throws IllegalStateException if {@link #setup(Configuration)} has not been called
     */
    public void perform(Job job) {
        Layer network = requireNetwork();
        // Held as Object so both branches below are reachable; declaring it as DataSet made the
        // INDArray branch dead code (and does not compile if getWork() is not DataSet-typed).
        Object work = job.getWork();
        if (work instanceof DataSet) {
            network.fit(((DataSet) work).getFeatureMatrix());
        } else if (work instanceof INDArray) {
            network.fit((INDArray) work);
        }
        job.setResult(network.params());
    }

    /**
     * Replaces the layer's parameters with the array supplied as the first argument.
     *
     * @param objArr update payload; element 0 must be a non-null {@link INDArray} of parameters
     * @throws IllegalArgumentException if no argument is given or the first one is not an INDArray
     * @throws IllegalStateException if {@link #setup(Configuration)} has not been called
     */
    public void update(Object... objArr) {
        if (objArr == null || objArr.length == 0 || !(objArr[0] instanceof INDArray)) {
            throw new IllegalArgumentException(
                    "update expects an INDArray of parameters as its first argument");
        }
        requireNetwork().setParams((INDArray) objArr[0]);
    }

    /**
     * Builds the layer from the JSON configuration stored under {@value #NEURAL_NET_CONF_KEY}.
     *
     * <p>A 1 x numParams array is allocated as the parameter view and the layer is created with
     * parameter initialization enabled.
     *
     * @param configuration source of the serialized network configuration
     * @throws IllegalStateException if the configuration has no value for the key
     */
    public void setup(Configuration configuration) {
        String json = configuration.get(NEURAL_NET_CONF_KEY);
        if (json == null) {
            throw new IllegalStateException("Missing configuration value for " + NEURAL_NET_CONF_KEY);
        }
        NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(json);

        int numParams = LayerFactories.getFactory(conf.getLayer()).initializer().numParams(conf, true);
        INDArray paramsView = Nd4j.create(1, numParams);

        // No listeners are attached. The explicit (raw) cast on null is kept as in the original;
        // presumably it disambiguates the create(...) overload — confirm before removing.
        this.neuralNetwork = LayerFactories.getFactory(conf.getLayer())
                .create(conf, (Collection) null, 0, paramsView, true);
    }

    /** Returns the layer built by {@link #setup(Configuration)}, failing fast if it is missing. */
    private Layer requireNetwork() {
        if (neuralNetwork == null) {
            throw new IllegalStateException("setup(Configuration) must be called before use");
        }
        return neuralNetwork;
    }
}
