package org.bigml.mimir.deepnet.network;

import com.fasterxml.jackson.databind.JsonNode;
import java.util.List;
import java.util.Map;
import org.bigml.mimir.deepnet.layers.Layer;
import org.bigml.mimir.math.Vectors;
import org.bigml.mimir.utils.Json;
import org.bigml.mimir.utils.fields.FieldCollection;

/**
 * Predictor that evaluates several sub-{@link Network}s on one input and combines their outputs.
 *
 * <p>Every sub-network receives the same input vector, except that networks whose
 * {@code useEmbedding()} is true receive the vector prefixed with the output of the optional
 * {@link Embedding} (when one is configured). The raw per-network outputs are gathered into a
 * matrix by {@link #predictions(double[])}. That matrix is reduced to a single vector by
 * {@code NetworkPredictor.toProbabilities}.
 */
public class Multinetwork extends NetworkPredictor {
    private final FieldCollection _preprocessor;
    private final Embedding _embedding;
    private final Network[] _networks;
    private final int _numberOfOutputs;
    private static final long serialVersionUID = 1;

    /**
     * Creates a multi-network predictor.
     *
     * @param networks the sub-networks to evaluate; the array is stored as-is, not copied
     * @param preprocessor converts raw {@code List}/{@code Map} rows into input vectors
     * @param embedding optional embedding prepended to the input of networks that use it; may be
     *     {@code null}
     * @param numberOfOutputs number of rows in the matrix returned by {@link #predictions(double[])}
     */
    public Multinetwork(Network[] networks, FieldCollection preprocessor, Embedding embedding, int numberOfOutputs) {
        this._networks = networks;
        this._preprocessor = preprocessor;
        this._embedding = embedding;
        this._numberOfOutputs = numberOfOutputs;
    }

    /**
     * Runs every sub-network on the given input vector.
     *
     * @param input the preprocessed input vector
     * @return a matrix indexed {@code [output][network]}; entries a network does not produce stay 0
     * @throws IllegalStateException if a sub-network yields more outputs than this predictor was
     *     configured for
     */
    @Override
    public double[][] predictions(double[] input) {
        // Compute the embedding once; it is shared by every network that consumes it.
        double[] embeddedInput = this._embedding != null
                ? Vectors.concat(this._embedding.embed(input), input)
                : input;
        double[][] outputs = new double[this._numberOfOutputs][this._networks.length];
        for (int net = 0; net < this._networks.length; net++) {
            Network network = this._networks[net];
            double[] prediction = network.predict(network.useEmbedding() ? embeddedInput : input);
            if (prediction.length > this._numberOfOutputs) {
                throw new IllegalStateException("Network " + net + " produced " + prediction.length
                        + " outputs but only " + this._numberOfOutputs + " were expected");
            }
            // Transpose: column `net` of the result holds this network's outputs.
            for (int out = 0; out < prediction.length; out++) {
                outputs[out][net] = prediction[out];
            }
        }
        return outputs;
    }

    /** Preprocesses a positional row and returns its per-network output matrix. */
    @Override
    public double[][] predictions(List<Object> list) {
        return predictions(this._preprocessor.toDoubles(list));
    }

    /** Preprocesses a named row and returns its per-network output matrix. */
    @Override
    public double[][] predictions(Map<String, Object> map) {
        return predictions(this._preprocessor.toDoubles(map));
    }

    /** Returns the combined prediction for an already-preprocessed input vector. */
    @Override
    public double[] predict(double[] input) {
        return NetworkPredictor.toProbabilities(predictions(input));
    }

    /** Returns the combined prediction for a positional row. */
    @Override
    public double[] predict(List<Object> list) {
        return NetworkPredictor.toProbabilities(predictions(list));
    }

    /** Returns the combined prediction for a named row. */
    @Override
    public double[] predict(Map<String, Object> map) {
        return NetworkPredictor.toProbabilities(predictions(map));
    }

    /** Returns the field collection used to preprocess raw rows. */
    @Override
    public FieldCollection getFieldCollection() {
        return this._preprocessor;
    }

    /**
     * Builds a {@code Multinetwork} from its JSON description.
     *
     * <p>The output exposition is read from the top-level {@code "output_exposition"} key, or, if
     * absent, from the first entry of {@code "networks"}. The number of outputs is the size of its
     * {@code "values"} array, or 1 when there is none.
     *
     * @param jsonNode the model JSON; must contain a {@code "networks"} array
     * @param fieldCollection preprocessor for raw input rows
     * @return the assembled predictor
     * @throws IllegalArgumentException if {@code "networks"} is missing or not an array, or if no
     *     output exposition can be found
     */
    public static Multinetwork createMultiNetwork(JsonNode jsonNode, FieldCollection fieldCollection) {
        JsonNode networks = jsonNode.get("networks");
        if (networks == null || !networks.isArray()) {
            throw new IllegalArgumentException("No networks array found");
        }
        JsonNode outputExposition = findOutputExposition(jsonNode, networks);
        Embedding embedding = Embedding.createEmbedding(jsonNode.get("trees"));
        Network[] subnets = new Network[networks.size()];
        for (int i = 0; i < subnets.length; i++) {
            subnets[i] = createSubnet(networks.get(i), outputExposition);
        }
        int numberOfOutputs = outputExposition.has("values") ? outputExposition.get("values").size() : 1;
        return new Multinetwork(subnets, fieldCollection, embedding, numberOfOutputs);
    }

    /**
     * Locates the output exposition: top level first, then the first sub-network.
     *
     * @throws IllegalArgumentException if neither location provides one (including when
     *     {@code networks} is empty)
     */
    private static JsonNode findOutputExposition(JsonNode jsonNode, JsonNode networks) {
        if (jsonNode.has("output_exposition")) {
            return jsonNode.get("output_exposition");
        }
        // get(0) returns null on an empty array; treat that as "not found" rather than NPE.
        JsonNode first = networks.get(0);
        if (first == null || !first.has("output_exposition")) {
            throw new IllegalArgumentException("No output exposition found");
        }
        return first.get("output_exposition");
    }

    /**
     * Builds one sub-network from its JSON entry and the shared output exposition.
     *
     * <p>A missing {@code "trees"} flag is read as {@code false} instead of throwing. NOTE(review):
     * this flag is presumably what {@code Network.useEmbedding()} reports; confirm against
     * {@code Network}.
     */
    private static Network createSubnet(JsonNode jsonNode, JsonNode outputExposition) {
        return new Network(
                Layer.makeLayers(jsonNode.get("layers")),
                Json.getDoubleOrNaN(outputExposition, "mean"),
                Json.getDoubleOrNaN(outputExposition, "stdev"),
                jsonNode.path("trees").asBoolean());
    }
}
