package org.deeplearning4j.graph.models.embeddings;

import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.graph.models.BinaryTree;
import org.deeplearning4j.graph.models.deepwalk.DeepWalk;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/graph/models/embeddings/InMemoryGraphLookupTable.class */
public class InMemoryGraphLookupTable implements GraphVectorLookupTable {
    /** Number of vertices; {@link #vertexVectors} has one row per vertex. */
    protected int nVertices;
    /** Length of every vertex vector and inner-node vector. */
    protected int vectorSize;
    /** Tree supplying, per vertex, its code bits, code length and inner-node path. */
    protected BinaryTree tree;
    /** Vertex embeddings: {@code nVertices x vectorSize}. */
    protected INDArray vertexVectors;
    /** Inner-node vectors: {@code (nVertices - 1) x vectorSize}, indexed by the tree's inner-node ids. */
    protected INDArray outWeights;
    /** Step size used by {@link #iterate(int, int)}. */
    protected double learningRate;
    /**
     * Precomputed sigmoid values for x sampled uniformly over [-MAX_EXP, MAX_EXP).
     * Not read anywhere in this class; presumably kept for subclasses.
     */
    protected double[] expTable;
    /** Half-width of the input range covered by {@link #expTable}. */
    protected static double MAX_EXP = 6.0d;

    /**
     * Creates a lookup table and randomly initializes its weights.
     *
     * @param nVertices    number of vertices (rows of the vertex-vector matrix)
     * @param vectorSize   length of each vertex / inner-node vector
     * @param tree         tree providing code, code length and inner-node path per vertex
     * @param learningRate step size applied by {@link #iterate(int, int)}
     */
    public InMemoryGraphLookupTable(int nVertices, int vectorSize, BinaryTree tree, double learningRate) {
        this.nVertices = nVertices;
        this.vectorSize = vectorSize;
        this.tree = tree;
        this.learningRate = learningRate;
        // NOTE(review): resetWeights() is public and overridable, so a subclass override runs
        // before the subclass constructor body. Kept for compatibility.
        resetWeights();
        // NOTE(review): the table size reuses DeepWalk.STATUS_UPDATE_FREQUENCY, which looks
        // incidental rather than meaningful -- confirm.
        this.expTable = new double[DeepWalk.STATUS_UPDATE_FREQUENCY];
        for (int i = 0; i < expTable.length; i++) {
            // Divide in floating point: with int/int division every entry collapsed to
            // sigmoid(-MAX_EXP) because i / expTable.length was always 0.
            double x = (((double) i / expTable.length) * 2.0d - 1.0d) * MAX_EXP;
            double exp = FastMath.exp(x);
            expTable[i] = exp / (exp + 1.0d);
        }
    }

    /**
     * Returns the internal vertex-vector matrix (not a copy).
     *
     * @return the {@code nVertices x vectorSize} matrix of vertex vectors
     */
    public INDArray getVertexVectors() {
        return vertexVectors;
    }

    /**
     * Returns the internal inner-node weight matrix (not a copy).
     *
     * @return the {@code (nVertices - 1) x vectorSize} matrix of inner-node vectors
     */
    public INDArray getOutWeights() {
        return outWeights;
    }

    /** @return the length of each vertex vector */
    @Override // org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable
    public int vectorSize() {
        return vectorSize;
    }

    /**
     * Re-initializes both weight matrices with random values shifted by -0.5 and scaled by
     * {@code 1 / vectorSize}. Presumably this gives values in [-0.5, 0.5) / vectorSize, assuming
     * {@code Nd4j.rand} samples [0, 1) -- confirm.
     */
    @Override // org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable
    public void resetWeights() {
        this.vertexVectors = Nd4j.rand(new int[] {nVertices, vectorSize}).subi(0.5d).divi(vectorSize);
        // One row fewer than vertices -- presumably one per inner node of a binary tree with
        // nVertices leaves; verify against BinaryTree.
        this.outWeights = Nd4j.rand(new int[] {nVertices - 1, vectorSize}).subi(0.5d).divi(vectorSize);
    }

    /**
     * Performs one gradient-descent step on the loss {@code -log calculateProb(first, second)}.
     * Every vector returned by {@link #vectorsAndGradients(int, int)} is updated in place as
     * {@code v -= learningRate * grad}. The underlying weight matrices change only if
     * {@code getRow} returns a view -- assumed here, confirm against the ND4J version in use.
     *
     * @param first  index of the vertex whose vector is updated
     * @param second index of the vertex whose tree path selects the inner nodes
     */
    @Override // org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable
    public void iterate(int first, int second) {
        INDArray[][] vecAndGrads = vectorsAndGradients(first, second);
        INDArray[] vectors = vecAndGrads[0];
        INDArray[] gradients = vecAndGrads[1];
        Level1 level1 = Nd4j.getBlasWrapper().level1();
        for (int i = 0; i < vectors.length; i++) {
            level1.axpy(vectors[i].length(), -learningRate, gradients[i], vectors[i]);
        }
    }

    /**
     * Computes the vectors involved in the hierarchical-softmax loss for the pair
     * ({@code first}, {@code second}), together with their gradients.
     *
     * @param first  index of the vertex whose vector is the input
     * @param second index of the vertex whose tree code/path defines the inner nodes
     * @return a {@code 2 x (codeLength + 1)} array. Row 0 holds the vectors and row 1 the
     *         matching gradients. Index 0 is the vertex vector of {@code first}; indices 1..codeLength
     *         are the inner-node vectors along the path of {@code second}.
     */
    public INDArray[][] vectorsAndGradients(int first, int second) {
        INDArray vec = vertexVectors.getRow(first);
        int codeLength = tree.getCodeLength(second);
        long code = tree.getCode(second);
        int[] innerNodesForVertex = tree.getPathInnerNodes(second);

        // Sized by the loop bound (codeLength) rather than the path length, so no slot is left null.
        INDArray[][] out = new INDArray[2][codeLength + 1];
        Level1 level1 = Nd4j.getBlasWrapper().level1();
        INDArray accumGradient = Nd4j.create(vec.shape());

        for (int i = 0; i < codeLength; i++) {
            INDArray innerNodeVec = outWeights.getRow(innerNodesForVertex[i]);
            double sigmoidDot = sigmoid(Nd4j.getBlasWrapper().dot(innerNodeVec, vec));
            // Matches calculateProb's terms: d(-log sigmoid(x))/dx = sigmoid(x) - 1 for a set bit,
            // and d(-log sigmoid(-x))/dx = sigmoid(x) for a clear bit.
            double gradScale = getBit(code, i) ? sigmoidDot - 1.0d : sigmoidDot;
            INDArray innerNodeGrad = vec.mul(gradScale);
            // Gradient w.r.t. the vertex vector accumulates over all inner nodes on the path.
            level1.axpy(vec.length(), gradScale, innerNodeVec, accumGradient);
            out[0][i + 1] = innerNodeVec;
            out[1][i + 1] = innerNodeGrad;
        }

        out[0][0] = vec;
        out[1][0] = accumGradient;
        return out;
    }

    /**
     * Computes the hierarchical-softmax probability of {@code second} given {@code first}. It is the
     * product over the path's inner nodes of {@code sigmoid(dot)} when the code bit is set and
     * {@code sigmoid(-dot)} otherwise.
     *
     * @param first  index of the input vertex
     * @param second index of the vertex whose tree path is evaluated
     * @return the probability; {@code 1.0} when the code length is 0
     */
    public double calculateProb(int first, int second) {
        INDArray vec = vertexVectors.getRow(first);
        int codeLength = tree.getCodeLength(second);
        long code = tree.getCode(second);
        int[] innerNodesForVertex = tree.getPathInnerNodes(second);
        double prob = 1.0d;
        for (int i = 0; i < codeLength; i++) {
            double dot = Nd4j.getBlasWrapper().dot(outWeights.getRow(innerNodesForVertex[i]), vec);
            prob *= getBit(code, i) ? sigmoid(dot) : sigmoid(-dot);
        }
        return prob;
    }

    /**
     * Returns the negative log-probability {@code -log calculateProb(first, second)}.
     * NOTE(review): for long paths the product in calculateProb can underflow to 0, which makes
     * this return +Infinity. Summing log-sigmoids would be more stable.
     *
     * @param first  index of the input vertex
     * @param second index of the target vertex
     * @return the loss value
     */
    public double calculateScore(int first, int second) {
        return -FastMath.log(calculateProb(first, second));
    }

    /** @return the tree used to look up codes and inner-node paths */
    public BinaryTree getTree() {
        return tree;
    }

    /**
     * @param idx inner-node index
     * @return row {@code idx} of the inner-node weight matrix
     */
    public INDArray getInnerNodeVector(int idx) {
        return outWeights.getRow(idx);
    }

    /**
     * @param idx vertex index
     * @return row {@code idx} of the vertex-vector matrix
     */
    @Override // org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable
    public INDArray getVector(int idx) {
        return vertexVectors.getRow(idx);
    }

    /** @param learningRate new step size for subsequent {@link #iterate(int, int)} calls */
    @Override // org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @return the number of vertices */
    @Override // org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable
    public int getNumVertices() {
        return nVertices;
    }

    /** Logistic function {@code 1 / (1 + e^-x)}. */
    private static double sigmoid(double x) {
        return 1.0d / (1.0d + FastMath.exp(-x));
    }

    /**
     * Tests bit {@code bitNum} of a 64-bit code.
     * Uses a long shift: an int {@code 1 << bitNum} masks the shift distance to 5 bits and
     * sign-extends at bit 31, giving wrong answers for bit indices of 31 and above.
     */
    private static boolean getBit(long in, int bitNum) {
        return (in & (1L << bitNum)) != 0;
    }

    /**
     * Replaces the vertex-vector matrix. The argument is stored as-is; its shape is not validated.
     *
     * @param vertexVectors the new vertex-vector matrix
     */
    public void setVertexVectors(INDArray vertexVectors) {
        this.vertexVectors = vertexVectors;
    }
}
