package org.bigml.mimir.deepnet.network;

import com.fasterxml.jackson.databind.JsonNode;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.bigml.mimir.utils.Json;

/**
 * Tree-ensemble embedding.
 *
 * <p>The embedding holds one group of binary decision trees per input slice. For a given
 * input vector, each group produces a block of {@code nclasses} values: the element-wise
 * mean of the leaf coordinate vectors reached in each of the group's trees. The blocks are
 * concatenated, so the output has {@code groups * nclasses} entries.
 *
 * <p>NOTE(review): {@code nclasses} is taken from the leftmost leaf of the first tree; all
 * leaves of all trees are presumably expected to have that same length. Confirm against the
 * model producer.
 */
public class Embedding implements Serializable {
    // Field names and types are part of the serialized form; do not rename them.
    private final Node[][] _nodes;
    private final int _nclasses;
    private final int _outputSize;
    private static final long serialVersionUID = 1;

    /**
     * A node of a binary decision tree, either an internal split node or a leaf.
     *
     * <p>A non-null {@code _coordinates} identifies a leaf. Internal nodes carry a split
     * index/value and two children instead.
     */
    public static class Node implements Serializable {
        private final int _splitIndex;
        private final double _splitValue;
        private final Node _left;
        private final Node _right;
        private final double[] _coordinates;
        private static final long serialVersionUID = 1;

        /**
         * Creates an internal split node.
         *
         * @param splitIndex index into the input vector tested by this node
         * @param splitValue threshold; inputs strictly below it descend left, all others
         *     (including NaN) descend right
         * @param left subtree taken when {@code input[splitIndex] < splitValue}
         * @param right subtree taken otherwise
         */
        public Node(int splitIndex, double splitValue, Node left, Node right) {
            this._splitIndex = splitIndex;
            this._splitValue = splitValue;
            this._left = left;
            this._right = right;
            this._coordinates = null;
        }

        /**
         * Creates a leaf node.
         *
         * @param coordinates output vector for inputs reaching this leaf; stored by
         *     reference, not copied
         */
        public Node(double[] coordinates) {
            this._coordinates = coordinates;
            this._right = null;
            this._left = null;
            this._splitIndex = -1;
            this._splitValue = Double.NaN;
        }
    }

    /**
     * Builds an embedding from its JSON representation.
     *
     * <p>The layout read here is an array of groups. Each group is a two-element array:
     * <ul>
     *   <li>Element 0 is an array whose first entry is an integer offset. The offset is added
     *       to every split index of the group's trees, shifting them to index the full input
     *       vector.</li>
     *   <li>Element 1 is an array of serialized trees. A serialized tree node whose last entry
     *       is JSON null is a leaf holding its coordinates in entry 0. Any other node is
     *       {@code [splitIndex, splitValue, left, right]}.</li>
     * </ul>
     *
     * @param jsonNode the serialized embedding; may be {@code null} or JSON null
     * @return the embedding, or {@code null} when {@code jsonNode} is {@code null} or JSON null
     * @throws IllegalArgumentException if the JSON describes no groups or an empty group
     */
    public static Embedding createEmbedding(JsonNode jsonNode) {
        if (jsonNode == null || jsonNode.isNull()) {
            return null;
        }
        List<List<Node>> groups = new ArrayList<List<Node>>(jsonNode.size());
        for (int g = 0; g < jsonNode.size(); g++) {
            JsonNode group = jsonNode.get(g);
            int indexOffset = group.get(0).get(0).asInt();
            JsonNode trees = group.get(1);
            List<Node> forest = new ArrayList<Node>(trees.size());
            for (int t = 0; t < trees.size(); t++) {
                forest.add(_listToNode(trees.get(t), indexOffset));
            }
            groups.add(forest);
        }
        return new Embedding(groups);
    }

    /**
     * Creates an embedding from one list of trees per group.
     *
     * @param list one entry per group, each holding that group's tree roots. The lists are
     *     copied into arrays; the nodes themselves are shared.
     * @throws IllegalArgumentException if {@code list} is empty or any group has no trees.
     *     The output width is derived from the first tree, and an empty group would make
     *     {@link #embed} divide by zero.
     */
    public Embedding(List<List<Node>> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Embedding requires at least one group");
        }
        this._nodes = new Node[list.size()][];
        for (int g = 0; g < list.size(); g++) {
            List<Node> group = list.get(g);
            if (group.isEmpty()) {
                throw new IllegalArgumentException("Embedding group " + g + " has no trees");
            }
            this._nodes[g] = group.toArray(new Node[group.size()]);
        }
        // Descend along left children to any leaf; its vector length fixes the block width.
        Node leaf = this._nodes[0][0];
        while (leaf._coordinates == null) {
            leaf = leaf._left;
        }
        this._nclasses = leaf._coordinates.length;
        this._outputSize = list.size() * this._nclasses;
    }

    /**
     * Maps an input vector to its embedding.
     *
     * @param dArr input vector, indexed by the offset-adjusted split indices of the trees
     * @return a new array of length {@code groups * nclasses}. Block {@code g} holds the
     *     element-wise mean of the leaf vectors reached in group {@code g}'s trees.
     */
    public double[] embed(double[] dArr) {
        double[] output = new double[this._outputSize];
        for (int g = 0; g < this._nodes.length; g++) {
            Node[] forest = this._nodes[g];
            int start = g * this._nclasses;
            for (Node tree : forest) {
                double[] leaf = predict(tree, dArr);
                for (int k = 0; k < leaf.length; k++) {
                    output[start + k] += leaf[k];
                }
            }
            // Turn the per-group sum into a mean over that group's trees.
            for (int k = start; k < start + this._nclasses; k++) {
                output[k] /= forest.length;
            }
        }
        return output;
    }

    /**
     * Recursively converts one serialized tree node into a {@link Node}.
     *
     * <p>If the node's last entry is JSON null, it is a leaf whose coordinates are in entry 0.
     * Otherwise it is {@code [splitIndex, splitValue, left, right]}.
     *
     * @param jsonNode serialized tree node
     * @param i offset added to every split index in this subtree
     * @return the root of the converted subtree
     */
    private static Node _listToNode(JsonNode jsonNode, int i) {
        if (jsonNode.get(jsonNode.size() - 1).isNull()) {
            return new Node(Json.get1DArray(jsonNode.get(0)));
        }
        int splitIndex = jsonNode.get(0).asInt() + i;
        double splitValue = jsonNode.get(1).asDouble();
        Node left = _listToNode(jsonNode.get(2), i);
        Node right = _listToNode(jsonNode.get(3), i);
        return new Node(splitIndex, splitValue, left, right);
    }

    /**
     * Walks a tree from {@code node} to the leaf selected by {@code dArr}.
     *
     * <p>At each split the walk goes left when {@code dArr[splitIndex] < splitValue}, and right
     * otherwise (a NaN input therefore goes right).
     *
     * @return the leaf's coordinate array itself (not a copy); callers must not modify it
     */
    private static double[] predict(Node node, double[] dArr) {
        Node current = node;
        while (current._coordinates == null) {
            current = dArr[current._splitIndex] < current._splitValue
                    ? current._left
                    : current._right;
        }
        return current._coordinates;
    }
}
