package org.bigml.mimir.deepnet.network;

import com.fasterxml.jackson.databind.JsonNode;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.bigml.mimir.deepnet.layers.Layer;
import org.bigml.mimir.math.Matrices;
import org.bigml.mimir.math.Vectors;
import org.bigml.mimir.utils.Json;
import org.bigml.mimir.utils.fields.FieldCollection;

/* loaded from: input_file:org/bigml/mimir/deepnet/network/Network.class */
/**
 * Predictor that pushes a numeric input vector through an optional
 * {@link Embedding} and a stack of {@link Layer}s, and optionally rescales
 * the first output value with a stored mean and standard deviation.
 *
 * <p>Instances built without a {@link FieldCollection} (see
 * {@link #Network(Layer[], double, double, boolean)}) can only be used
 * through the {@code double[]} entry points; the {@code List} and
 * {@code Map} entry points need the field collection to convert raw values
 * into doubles and fail with a descriptive {@link NullPointerException}
 * otherwise.
 */
public class Network extends NetworkPredictor {
    private final FieldCollection _preprocessor;
    private final Embedding _embedding;
    private final Layer[] _layers;
    private final boolean _destandardize;
    private final double _yMean;
    private final double _yStDev;
    private final boolean _useTopLevelEmbedding;
    private static final long serialVersionUID = 1L;

    /**
     * Builds a network from its JSON description.
     *
     * <p>Reads the {@code "layers"} node for the layer stack, the
     * {@code "trees"} node for the embedding, and the {@code "mean"} and
     * {@code "stdev"} entries of the {@code "output_exposition"} node for the
     * output rescaling. The resulting network has its top-level-embedding
     * flag set to {@code false}.
     *
     * @param jsonNode        JSON description of the network
     * @param fieldCollection preprocessor used by the {@code List}/{@code Map}
     *                        prediction entry points
     * @return the configured network
     */
    public static Network createNetwork(JsonNode jsonNode, FieldCollection fieldCollection) {
        Layer[] layers = Layer.makeLayers(jsonNode.get("layers"));
        Embedding embedding = Embedding.createEmbedding(jsonNode.get("trees"));
        // NOTE(review): "output_exposition" may be absent, in which case this
        // node is null; presumably Json.getDoubleOrNaN tolerates that and
        // yields NaN - confirm against the helper.
        JsonNode outputExposition = jsonNode.get("output_exposition");
        double mean = Json.getDoubleOrNaN(outputExposition, "mean");
        double stdev = Json.getDoubleOrNaN(outputExposition, "stdev");
        return new Network(layers, fieldCollection, embedding, mean, stdev);
    }

    /**
     * Creates a fully specified network.
     *
     * @param layerArr        layers the input is propagated through
     * @param fieldCollection preprocessor for raw inputs; may be {@code null}
     *                        when only {@code double[]} inputs will be used
     * @param embedding       optional embedding; when non-null its output is
     *                        prepended to the input before propagation
     * @param d               mean used to rescale the first output value
     * @param d2              standard deviation used to rescale the first
     *                        output value
     * @param z               value reported by {@link #useEmbedding()}
     */
    public Network(Layer[] layerArr, FieldCollection fieldCollection, Embedding embedding, double d, double d2, boolean z) {
        this._preprocessor = fieldCollection;
        this._embedding = embedding;
        this._useTopLevelEmbedding = z;
        this._layers = layerArr;
        this._yMean = d;
        this._yStDev = d2;
        // Rescaling is skipped only when BOTH statistics are NaN. If exactly
        // one of them is NaN, rescaling stays enabled and the first output
        // value becomes NaN.
        this._destandardize = !(Double.isNaN(d) && Double.isNaN(d2));
    }

    /**
     * Creates a network whose top-level-embedding flag is {@code false}.
     *
     * @param layerArr        layers the input is propagated through
     * @param fieldCollection preprocessor for raw inputs
     * @param embedding       optional embedding applied before the layers
     * @param d               mean used to rescale the first output value
     * @param d2              standard deviation used to rescale the first
     *                        output value
     */
    public Network(Layer[] layerArr, FieldCollection fieldCollection, Embedding embedding, double d, double d2) {
        this(layerArr, fieldCollection, embedding, d, d2, false);
    }

    /**
     * Creates a network with neither a preprocessor nor an embedding. Such a
     * network only supports the {@code double[]} prediction entry points.
     *
     * @param layerArr layers the input is propagated through
     * @param d        mean used to rescale the first output value
     * @param d2       standard deviation used to rescale the first output value
     * @param z        value reported by {@link #useEmbedding()}
     */
    public Network(Layer[] layerArr, double d, double d2, boolean z) {
        this(layerArr, null, null, d, d2, z);
    }

    /**
     * Runs the network on an already-numeric input vector.
     *
     * <p>When an embedding is present, its output is concatenated in front of
     * the original input. The vector is converted to floats, propagated
     * through the layers and converted back to doubles. When rescaling is
     * enabled, only element 0 of the result is transformed as
     * {@code value * stdev + mean}.
     *
     * @param dArr numeric input vector
     * @return the network output
     */
    @Override // org.bigml.mimir.deepnet.network.NetworkPredictor, org.bigml.mimir.Predictor
    public double[] predict(double[] dArr) {
        double[] input = dArr;
        if (this._embedding != null) {
            input = Vectors.concat(this._embedding.embed(dArr), dArr);
        }
        double[] output = Matrices.toDouble(Layer.propagate(this._layers, Matrices.toFloat(input)));
        if (this._destandardize) {
            output[0] = (output[0] * this._yStDev) + this._yMean;
        }
        return output;
    }

    /**
     * Converts the raw values with the field collection and predicts.
     *
     * @param list raw input values
     * @return the network output
     * @throws NullPointerException if this network has no field collection
     */
    @Override // org.bigml.mimir.deepnet.network.NetworkPredictor, org.bigml.mimir.Predictor
    public double[] predict(List<Object> list) {
        return predict(requirePreprocessor().toDoubles(list));
    }

    /**
     * Converts the raw values with the field collection and predicts.
     *
     * @param map raw input values
     * @return the network output
     * @throws NullPointerException if this network has no field collection
     */
    @Override // org.bigml.mimir.deepnet.network.NetworkPredictor, org.bigml.mimir.Predictor
    public double[] predict(Map<String, Object> map) {
        return predict(requirePreprocessor().toDoubles(map));
    }

    /**
     * Predicts and reshapes the result so that every output value sits in its
     * own single-element row.
     *
     * @param dArr numeric input vector
     * @return an array with one row per output value, each row of length 1
     */
    @Override // org.bigml.mimir.deepnet.network.NetworkPredictor
    public double[][] predictions(double[] dArr) {
        double[] output = predict(dArr);
        double[][] rows = new double[output.length][1];
        for (int i = 0; i < output.length; i++) {
            rows[i][0] = output[i];
        }
        return rows;
    }

    /**
     * Converts the raw values with the field collection, then behaves like
     * {@link #predictions(double[])}.
     *
     * @param list raw input values
     * @return an array with one row per output value, each row of length 1
     * @throws NullPointerException if this network has no field collection
     */
    @Override // org.bigml.mimir.deepnet.network.NetworkPredictor
    public double[][] predictions(List<Object> list) {
        return predictions(requirePreprocessor().toDoubles(list));
    }

    /**
     * Converts the raw values with the field collection, then behaves like
     * {@link #predictions(double[])}.
     *
     * @param map raw input values
     * @return an array with one row per output value, each row of length 1
     * @throws NullPointerException if this network has no field collection
     */
    @Override // org.bigml.mimir.deepnet.network.NetworkPredictor
    public double[][] predictions(Map<String, Object> map) {
        return predictions(requirePreprocessor().toDoubles(map));
    }

    /**
     * @return the field collection given at construction time; {@code null}
     *         for networks built without one
     */
    @Override // org.bigml.mimir.deepnet.network.NetworkPredictor
    public FieldCollection getFieldCollection() {
        return this._preprocessor;
    }

    /**
     * Reports the top-level-embedding flag given at construction time. The
     * flag is only stored and returned here; the prediction methods of this
     * class do not consult it.
     *
     * @return the flag passed to the constructor
     */
    public boolean useEmbedding() {
        return this._useTopLevelEmbedding;
    }

    /**
     * Returns the field collection, failing with a descriptive message when
     * the network was built without one.
     */
    private FieldCollection requirePreprocessor() {
        return Objects.requireNonNull(
                this._preprocessor,
                "This Network was built without a FieldCollection; only double[] inputs are supported");
    }
}
