package net.librec.recommender.cf.ranking;

import java.util.Iterator;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Listwise matrix-factorization ranking recommender (presumably ListRank-MF, Shi et al., RecSys 2010).
 *
 * <p>For each user {@code u} with training items {@code R_u}, the model minimises the cross entropy
 * between two "top-one" distributions over {@code R_u}:
 *
 * <pre>
 *   target(u, j) = exp(r_uj / maxRate) / sum_{k in R_u} exp(r_uk / maxRate)
 *   pred(u, j)   = exp(g(p_u . q_j))   / sum_{k in R_u} exp(g(p_u . q_k)),   g = Maths.logistic
 *
 *   L = - sum_u sum_{j in R_u} target(u, j) * log pred(u, j)
 *       + regUser/2 * ||P||^2 + regItem/2 * ||Q||^2
 * </pre>
 *
 * <p>Training is full-batch gradient descent with an adaptive step size. The learning rate is doubled
 * at the start of every iteration and halved (with the step retried from the previous factors) for as
 * long as the loss would increase.
 */
public class ListRankMFRecommender extends MatrixFactorizationRecommender {
    /**
     * Per-user normaliser of the target distribution: {@code userExp[u] = sum_j exp(r_uj / maxRate)}
     * over the user's training entries. Filled in {@link #setup()}.
     */
    public DenseVector userExp;

    /**
     * Initialises the factor matrices and precomputes {@link #userExp}.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        // The rest of this class uses DenseMatrix.scale for its return value, and trainModel's rollback
        // relies on it not mutating the receiver. So the result must be assigned here, otherwise the
        // 0.1 shrink is silently lost. Assigning is also correct if scale were in-place.
        this.userFactors.init(1.0d);
        this.userFactors = this.userFactors.scale(0.1d);
        this.itemFactors.init(1.0d);
        this.itemFactors = this.itemFactors.scale(0.1d);

        this.userExp = new DenseVector(this.numUsers);
        Iterator<MatrixEntry> entries = this.trainMatrix.iterator();
        while (entries.hasNext()) {
            MatrixEntry entry = entries.next();
            this.userExp.add(entry.row(), Math.exp(entry.get() / this.maxRate));
        }
    }

    /**
     * Runs {@code numIterations} full-batch gradient steps with an adaptive learning rate.
     *
     * @throws LibrecException declared by the parent contract
     */
    @Override
    protected void trainModel() throws LibrecException {
        double lastLoss = getLoss(this.userFactors, this.itemFactors);
        for (int iter = 1; iter <= this.numIterations; iter++) {
            // Snapshot for rollback. applyStep builds new matrices, so these references stay unchanged.
            DenseMatrix prevUserFactors = this.userFactors;
            DenseMatrix prevItemFactors = this.itemFactors;

            // Optimistically enlarge the step; it is shrunk below if the loss would go up.
            this.learnRate *= 2.0f;

            DenseMatrix userGrads = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix itemGrads = new DenseMatrix(this.numItems, this.numFactors);
            accumulateGradients(userGrads, itemGrads);

            applyStep(prevUserFactors, prevItemFactors, userGrads, itemGrads);
            // Halve and retry from the snapshot while the loss increases. The loop terminates:
            // once learnRate reaches 0 the step is the identity and the loss equals lastLoss,
            // and a NaN loss makes the comparison false.
            while (this.loss > lastLoss) {
                this.learnRate /= 2.0f;
                applyStep(prevUserFactors, prevItemFactors, userGrads, itemGrads);
            }

            this.LOG.info(" iter " + iter + ": loss = " + this.loss + ", delta_loss = " + (lastLoss - this.loss));
            lastLoss = this.loss;
        }
    }

    /**
     * Computes the listwise cross-entropy loss plus L2 regularisation for the given factors.
     *
     * @param userMat candidate user-factor matrix (row u = user u)
     * @param itemMat candidate item-factor matrix (row j = item j)
     * @return the regularised training loss
     */
    public double getLoss(DenseMatrix userMat, DenseMatrix itemMat) {
        double loss = 0.0d;
        for (int u = 0; u < this.numUsers; u++) {
            // log of the softmax denominator. Unused (and -Infinity) for users with no training entries.
            double logNorm = Math.log(expScoreSum(userMat, itemMat, u));
            // Iterate the user's row, the same entries that accumulateGradients differentiates.
            Iterator<VectorEntry> entries = this.trainMatrix.row(u).iterator();
            while (entries.hasNext()) {
                VectorEntry entry = entries.next();
                double targetProb = Math.exp(entry.get() / this.maxRate) / this.userExp.get(u);
                // log(exp(g(x)) / Z) == g(x) - log(Z); avoids a pointless exp/log round-trip.
                double logPredProb =
                        Maths.logistic(DenseMatrix.rowMult(userMat, u, itemMat, entry.index())) - logNorm;
                loss -= targetProb * logPredProb;
            }
            for (int f = 0; f < this.numFactors; f++) {
                double w = userMat.get(u, f);
                loss += 0.5d * this.regUser * w * w;
            }
        }
        for (int j = 0; j < this.numItems; j++) {
            for (int f = 0; f < this.numFactors; f++) {
                double w = itemMat.get(j, f);
                loss += 0.5d * this.regItem * w * w;
            }
        }
        return loss;
    }

    /**
     * Returns {@code sum_j exp(logistic(userMat[u] . itemMat[j]))} over the column indices of row
     * {@code u} of the training matrix. This is the softmax denominator of the predicted distribution.
     */
    private double expScoreSum(DenseMatrix userMat, DenseMatrix itemMat, int u) {
        double sum = 0.0d;
        Iterator<Integer> items = this.trainMatrix.getColumns(u).iterator();
        while (items.hasNext()) {
            sum += Math.exp(Maths.logistic(DenseMatrix.rowMult(userMat, u, itemMat, items.next().intValue())));
        }
        return sum;
    }

    /**
     * Adds the gradient of the (unregularised) listwise loss at the current factors into
     * {@code userGrads} / {@code itemGrads}.
     */
    private void accumulateGradients(DenseMatrix userGrads, DenseMatrix itemGrads) {
        for (int u = 0; u < this.numUsers; u++) {
            double predNorm = expScoreSum(this.userFactors, this.itemFactors, u);
            Iterator<VectorEntry> entries = this.trainMatrix.row(u).iterator();
            while (entries.hasNext()) {
                VectorEntry entry = entries.next();
                int j = entry.index();
                double score = DenseMatrix.rowMult(this.userFactors, u, this.itemFactors, j);
                double predProb = Math.exp(Maths.logistic(score)) / predNorm;
                double targetProb = Math.exp(entry.get() / this.maxRate) / this.userExp.get(u);
                // Softmax cross-entropy gradient w.r.t. g(score), chained through g'(score).
                double coeff = (predProb - targetProb) * Maths.logisticGradientValue(score);
                for (int f = 0; f < this.numFactors; f++) {
                    double userWeight = this.userFactors.get(u, f);
                    userGrads.add(u, f, coeff * this.itemFactors.get(j, f));
                    itemGrads.add(j, f, coeff * userWeight);
                }
            }
        }
    }

    /**
     * Sets the factors to one regularised gradient step from the given base matrices at the current
     * learning rate. Then recomputes {@code this.loss}.
     */
    private void applyStep(DenseMatrix baseUser, DenseMatrix baseItem, DenseMatrix userGrads, DenseMatrix itemGrads) {
        this.userFactors = baseUser.add(baseUser.scale((-this.learnRate) * this.regUser))
                .add(userGrads.scale(-this.learnRate));
        this.itemFactors = baseItem.add(baseItem.scale((-this.learnRate) * this.regItem))
                .add(itemGrads.scale(-this.learnRate));
        this.loss = getLoss(this.userFactors, this.itemFactors);
    }
}
