package de.datexis.rnn.loss;

import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/rnn/loss/DosSantosPairwiseRankingLoss.class */
/**
 * Pairwise ranking loss in the style of dos Santos, Xiang and Zhou (2015), "Classifying
 * Relations by Ranking with Convolutional Neural Networks".
 *
 * <p>For each example with activated scores {@code s}, the per-class loss contributions are
 *
 * <pre>
 *   log(1 + exp(gamma * (mPlus  - s[y+])))   for every target class y+ (weighted by its label)
 *   log(1 + exp(gamma * (mMinus + s[c-])))   for the highest-scoring non-target class c-
 * </pre>
 *
 * The first term pushes target scores above {@code mPlus}; the second pushes the best competing
 * score below {@code -mMinus}. The defaults are gamma = 2, mPlus = 2.5 and mMinus = 0.5.
 *
 * <p>Arrays are expected to be rank 2, {@code [examples, classes]}; the negative class is chosen
 * by an argmax over dimension 1. Labels are presumably one-hot (TODO confirm with callers). A
 * mask, when given, is applied as a column vector, i.e. one weight per example.
 */
public class DosSantosPairwiseRankingLoss implements ILossFunction {
    protected static final Logger log = LoggerFactory.getLogger(DosSantosPairwiseRankingLoss.class);

    /**
     * Added (times the label value) to the scores of target classes before the argmax that picks
     * the negative class. It must be a large negative number so that targets are never picked.
     */
    private Number positiveClassExclusionFactor = -1000000;
    /** Scaling factor applied inside both softplus terms. */
    private Number gamma = 2;
    /** Margin that target-class scores are pushed above. */
    private Number mPlus = 2.5d;
    /** Margin; the best non-target score is pushed below {@code -mMinus}. */
    private Number mMinus = 0.5d;

    /** Creates the loss with the default parameters (gamma 2, mPlus 2.5, mMinus 0.5). */
    public DosSantosPairwiseRankingLoss() {
    }

    /**
     * Creates the loss with custom scaling and margins and the default exclusion factor.
     *
     * @param gamma  scaling factor
     * @param mPlus  positive margin
     * @param mMinus negative margin
     */
    public DosSantosPairwiseRankingLoss(Number gamma, Number mPlus, Number mMinus) {
        this.gamma = gamma;
        this.mPlus = mPlus;
        this.mMinus = mMinus;
    }

    /**
     * Creates the loss with integer-valued parameters.
     *
     * @param gamma                        scaling factor
     * @param mPlus                        positive margin
     * @param mMinus                       negative margin
     * @param positiveClassExclusionFactor large negative offset that hides target classes from
     *                                     the negative-class argmax
     */
    public DosSantosPairwiseRankingLoss(int gamma, int mPlus, int mMinus, int positiveClassExclusionFactor) {
        this(Integer.valueOf(gamma), Integer.valueOf(mPlus), Integer.valueOf(mMinus));
        this.positiveClassExclusionFactor = positiveClassExclusionFactor;
    }

    /**
     * Computes the unreduced per-class loss contributions.
     *
     * @param labels       target matrix; nonzero entries mark target classes
     * @param preOutput    network output before activation; not modified
     * @param activationFn activation applied to a copy of {@code preOutput}
     * @param mask         optional per-example weights (column vector), or {@code null}
     * @return array shaped like {@code preOutput} holding each class's loss contribution
     */
    public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray output = activationFn.getActivation(preOutput.dup(), true);

        // Computed for every class; only the target entries survive the product with labels.
        INDArray positiveTerm = softplus(output.rsub(mPlus).muli(gamma)).muli(labels);
        // Computed for every class; only the hardest negative survives the one-hot product.
        INDArray negativeTerm = softplus(output.add(mMinus).muli(gamma))
                .muli(hardestNegativeMask(labels, output));

        INDArray scoreArr = positiveTerm.addi(negativeTerm);
        if (mask != null) {
            scoreArr.muliColumnVector(mask);
        }
        return scoreArr;
    }

    /**
     * Computes the total loss over the minibatch.
     *
     * @param average if {@code true}, divide by the number of examples (rows)
     */
    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
                               boolean average) {
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);
        double score = scoreArr.sumNumber().doubleValue();
        if (average) {
            score /= scoreArr.size(0);
        }
        return score;
    }

    /** Computes the loss per example, summing the class contributions along dimension 1. */
    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        return scoreArray(labels, preOutput, activationFn, mask).sum(1);
    }

    /**
     * Computes the gradient with respect to {@code preOutput}.
     *
     * It backpropagates {@link #computeDlDx} through the activation function and then applies
     * the per-example mask, if one is given.
     */
    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray output = activationFn.getActivation(preOutput.dup(), true);
        INDArray gradient = activationFn.backprop(preOutput.dup(), computeDlDx(labels, output)).getFirst();
        if (mask != null) {
            gradient.muliColumnVector(mask);
        }
        return gradient;
    }

    /**
     * Computes the gradient of the loss with respect to the activated output.
     *
     * @param labels target matrix; nonzero entries mark target classes
     * @param output activated scores; NOT modified (the previous version mutated it in place)
     * @return array shaped like {@code output}
     */
    public INDArray computeDlDx(INDArray labels, INDArray output) {
        // d/ds log(1 + e^{g(m+ - s)}) = -g * sigmoid(g(m+ - s)). The sigmoid form stays finite
        // where e^x / (1 + e^x) would overflow to inf/inf = NaN.
        INDArray dPositive = Transforms.sigmoid(output.rsub(mPlus).muli(gamma)).muli(gamma).negi();
        // d/ds log(1 + e^{g(m- + s)}) = g * sigmoid(g(m- + s))
        INDArray dNegative = Transforms.sigmoid(output.add(mMinus).muli(gamma)).muli(gamma);
        return dPositive.muli(labels).addi(dNegative.muli(hardestNegativeMask(labels, output)));
    }

    /** Returns the score (see {@link #computeScore}) together with the gradient. */
    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput,
                                                          IActivation activationFn, INDArray mask,
                                                          boolean average) {
        return new Pair<>(computeScore(labels, preOutput, activationFn, mask, average),
                computeGradient(labels, preOutput, activationFn, mask));
    }

    /**
     * Builds a one-hot matrix that marks, per row, the highest-scoring non-target class.
     *
     * @param labels target matrix; not modified
     * @param output activated scores; not modified
     */
    private INDArray hardestNegativeMask(INDArray labels, INDArray output) {
        // Push target classes down by a large negative offset so the argmax skips them.
        INDArray candidates = output.add(labels.mul(positiveClassExclusionFactor));
        INDArray negativeIdx = candidates.argMax(1);
        INDArray oneHot = Nd4j.zeros(output.dataType(), output.shape());
        for (long row = 0; row < negativeIdx.length(); row++) {
            oneHot.putScalar(row, negativeIdx.getLong(row), 1.0);
        }
        return oneHot;
    }

    /**
     * Computes log(1 + e^x) element-wise as max(x, 0) + log(1 + e^{-|x|}).
     *
     * This form never overflows. The naive form yields inf for large x, and inf times a zero
     * label or mask entry becomes NaN. The input is not modified.
     */
    private static INDArray softplus(INDArray x) {
        INDArray tail = Transforms.log(Transforms.exp(Transforms.abs(x).negi()).addi(1));
        return Transforms.max(x, 0.0).addi(tail);
    }

    /** Returns the simple class name, used as this loss function's identifier. */
    @Override
    public String name() {
        return getClass().getSimpleName();
    }

    public Number getGamma() {
        return this.gamma;
    }

    public void setGamma(Number number) {
        this.gamma = number;
    }

    public Number getmPlus() {
        return this.mPlus;
    }

    public void setmPlus(Number number) {
        this.mPlus = number;
    }

    public Number getmMinus() {
        return this.mMinus;
    }

    public void setmMinus(Number number) {
        this.mMinus = number;
    }

    public Number getPositiveClassExclusionFactor() {
        return this.positiveClassExclusionFactor;
    }

    public void setPositiveClassExclusionFactor(Number number) {
        this.positiveClassExclusionFactor = number;
    }
}
