package net.librec.recommender.cf.rating;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseMatrix;
import net.librec.math.structure.SparseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Bayesian Probabilistic Matrix Factorization (BPMF) rating predictor trained by Gibbs sampling.
 *
 * <p>The user and item latent-factor rows have Gaussian priors whose mean vector and precision
 * matrix are resampled every iteration from Normal-Wishart posteriors. Each iteration
 * (1) resamples the user and item hyperparameters, (2) resamples every user factor row given the
 * item factors, then (3) resamples every item factor row given the freshly sampled user factors.
 * Test-set predictions are the running average of {@code globalMean + userRow . itemRow} over the
 * samples that follow the burn-in.
 *
 * <p>Configuration keys read in {@link #setup()} (defaults in parentheses):
 * <ul>
 *   <li>{@code rec.recommender.user.mu}, {@code rec.recommender.item.mu} (0.0): value assigned to
 *       every component of the prior mean vector mu0;</li>
 *   <li>{@code rec.recommender.user.beta}, {@code rec.recommender.item.beta} (1.0): prior
 *       strength beta0 of the mean;</li>
 *   <li>{@code rec.recommender.user.wishart.scale}, {@code rec.recommender.item.wishart.scale}
 *       (1.0): diagonal value s of the prior Wishart scale matrix W0 = s * I;</li>
 *   <li>{@code rec.recommender.rating.sigma} (2.0): weight of the rating likelihood term. Despite
 *       its name it is used as an observation precision: it multiplies the data term.</li>
 * </ul>
 * The prior Wishart degrees of freedom nu0 are fixed to {@code numFactors}.
 *
 * <p>{@link #predict(int, int)} is only meaningful for entries of {@code testMatrix}, because the
 * prediction matrix copies the sparsity pattern of the test set.
 */
public class BPMFRecommender extends MatrixFactorizationRecommender {
    /** Number of leading Gibbs samples excluded from the prediction average (when more exist). */
    private static final int BURN_IN_ITERATIONS = 1;

    /** Configured prior hyperparameters for users (see class Javadoc). */
    private double userMu0;
    private double userBeta0;
    private double userWishartScale0;
    /** Configured prior hyperparameters for items (see class Javadoc). */
    private double itemMu0;
    private double itemBeta0;
    private double itemWishartScale0;

    /** Prior mean vectors mu0, every component set to the configured scalar. */
    private DenseVector userMu;
    private DenseVector itemMu;
    /** Inverse prior Wishart scale matrices W0^-1 = (1 / s) * I. */
    private DenseMatrix userWishartScale;
    private DenseMatrix itemWishartScale;
    /** Prior strengths beta0. */
    private double userBeta;
    private double itemBeta;
    /** Prior Wishart degrees of freedom nu0. */
    private double userWishartNu;
    private double itemWishartNu;
    /** Rating likelihood weight, applied as a precision in {@link #updateParameters}. */
    private double ratingSigma;
    /** Running average of the sampled predictions, on the sparsity pattern of the test set. */
    private SparseMatrix predictMatrix;

    /**
     * Gaussian hyperparameters of one side (users or items).
     *
     * <p>Note: despite its name, {@code variance} holds a precision (inverse covariance) matrix. It
     * is initialised from {@code cov().inv()} and added to the data precision when factor rows are
     * sampled.
     */
    public class HyperParameters {
        /** Mean vector of the factor-row prior. */
        public DenseVector mu;
        /** Precision matrix of the factor-row prior (see class note). */
        public DenseMatrix variance;

        HyperParameters(DenseVector mu, DenseMatrix variance) {
            this.mu = mu;
            this.variance = variance;
        }
    }

    /**
     * Reads the BPMF prior settings from the configuration.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        userMu0 = conf.getDouble("rec.recommender.user.mu", Double.valueOf(0.0d)).doubleValue();
        userBeta0 = conf.getDouble("rec.recommender.user.beta", Double.valueOf(1.0d)).doubleValue();
        userWishartScale0 = conf.getDouble("rec.recommender.user.wishart.scale", Double.valueOf(1.0d)).doubleValue();
        itemMu0 = conf.getDouble("rec.recommender.item.mu", Double.valueOf(0.0d)).doubleValue();
        itemBeta0 = conf.getDouble("rec.recommender.item.beta", Double.valueOf(1.0d)).doubleValue();
        itemWishartScale0 = conf.getDouble("rec.recommender.item.wishart.scale", Double.valueOf(1.0d)).doubleValue();
        ratingSigma = conf.getDouble("rec.recommender.rating.sigma", Double.valueOf(2.0d)).doubleValue();
    }

    /**
     * Builds the prior hyperparameter objects from the configured scalars and creates an empty
     * prediction matrix over the test entries.
     *
     * @throws LibrecException declared for subclasses; not thrown here
     */
    protected void initModel() throws LibrecException {
        userMu = new DenseVector(numFactors);
        userMu.setAll(userMu0);
        itemMu = new DenseVector(numFactors);
        itemMu.setAll(itemMu0);
        userBeta = userBeta0;
        itemBeta = itemBeta0;

        // samplingHyperParameters needs W0^-1. For W0 = s * I that is (1 / s) * I. The previous
        // code called inv() and discarded its result, which was only correct when s == 1.
        userWishartScale = new DenseMatrix(numFactors, numFactors);
        itemWishartScale = new DenseMatrix(numFactors, numFactors);
        for (int f = 0; f < numFactors; f++) {
            userWishartScale.set(f, f, 1.0d / userWishartScale0);
            itemWishartScale.set(f, f, 1.0d / itemWishartScale0);
        }
        userWishartNu = numFactors;
        itemWishartNu = numFactors;

        // Reuse the test set's sparsity pattern, but clear the copied ratings so predict() can
        // never return a ground-truth test rating (e.g. when numIterations is 0).
        predictMatrix = new SparseMatrix(testMatrix);
        for (MatrixEntry entry : testMatrix) {
            predictMatrix.set(entry.row(), entry.column(), 0.0d);
        }
    }

    /**
     * Runs {@code numIterations} Gibbs iterations and accumulates test-set predictions.
     *
     * @throws LibrecException if sampling fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        initModel();

        // Cache the per-user rows and per-item columns so each sweep need not re-extract them.
        List<SparseVector> userRatings = new ArrayList<>(numUsers);
        for (int u = 0; u < numUsers; u++) {
            userRatings.add(trainMatrix.row(u));
        }
        List<SparseVector> itemRatings = new ArrayList<>(numItems);
        for (int j = 0; j < numItems; j++) {
            itemRatings.add(trainMatrix.column(j));
        }

        HyperParameters userHyper = initialHyperParameters(userFactors);
        HyperParameters itemHyper = initialHyperParameters(itemFactors);

        // With a single iteration nothing follows the burn-in, so that one sample is kept instead.
        int burnIn = Math.min(BURN_IN_ITERATIONS, numIterations - 1);
        for (int iter = 0; iter < numIterations; iter++) {
            userHyper = samplingHyperParameters(userHyper, userFactors, userMu, userBeta,
                    userWishartScale, userWishartNu);
            itemHyper = samplingHyperParameters(itemHyper, itemFactors, itemMu, itemBeta,
                    itemWishartScale, itemWishartNu);

            // One Gibbs sweep: users given items, then items given the freshly sampled users.
            sampleFactors(userFactors, itemFactors, userRatings, userHyper);
            sampleFactors(itemFactors, userFactors, itemRatings, itemHyper);

            if (iter >= burnIn) {
                accumulatePredictions(iter - burnIn + 1);
            }
        }
    }

    /** Builds starting hyperparameters from the column means and inverse covariance of factors. */
    private HyperParameters initialHyperParameters(DenseMatrix factors) throws LibrecException {
        DenseVector mean = new DenseVector(numFactors);
        for (int f = 0; f < numFactors; f++) {
            mean.set(f, factors.columnMean(f));
        }
        return new HyperParameters(mean, factors.cov().inv());
    }

    /** Resamples in place every row of factors that has at least one training rating. */
    private void sampleFactors(DenseMatrix factors, DenseMatrix otherFactors,
            List<SparseVector> ratings, HyperParameters hyper) throws LibrecException {
        for (int idx = 0; idx < ratings.size(); idx++) {
            SparseVector row = ratings.get(idx);
            if (row.getCount() != 0) {
                factors.setRow(idx, updateParameters(otherFactors, row, hyper));
            }
        }
    }

    /** Folds the current sample into the running mean; sampleCount is this sample's 1-based index. */
    private void accumulatePredictions(int sampleCount) throws LibrecException {
        for (MatrixEntry entry : testMatrix) {
            int u = entry.row();
            int j = entry.column();
            double previous = predictMatrix.get(u, j);
            predictMatrix.set(u, j, (previous * (sampleCount - 1) + globalMean
                    + DenseMatrix.rowMult(userFactors, u, itemFactors, j)) / sampleCount);
        }
    }

    /**
     * Draws new Normal-Wishart hyperparameters for one side (users or items) from its posterior
     * given the current factor matrix.
     *
     * @param hyperParameters  current hyperparameters; updated in place and returned
     * @param factors          current factor matrix, one row per user or item
     * @param mu0              prior mean vector
     * @param beta0            prior strength of the mean
     * @param wishartScaleInv0 inverse of the prior Wishart scale matrix, W0^-1
     * @param wishartNu0       prior Wishart degrees of freedom
     * @return {@code hyperParameters}. Its {@code variance} is replaced by the sampled precision
     *         matrix, unless the Wishart sampler returns null. Its {@code mu} is replaced by the
     *         sampled mean, unless the Cholesky factorisation returns null.
     * @throws LibrecException if a matrix operation fails
     */
    protected HyperParameters samplingHyperParameters(HyperParameters hyperParameters,
            DenseMatrix factors, DenseVector mu0, double beta0, DenseMatrix wishartScaleInv0,
            double wishartNu0) throws LibrecException {
        int numRows = factors.numRows();
        int numColumns = factors.numColumns();

        DenseVector sampleMean = new DenseVector(numColumns);
        for (int c = 0; c < numColumns; c++) {
            sampleMean.set(c, factors.columnMean(c));
        }
        DenseMatrix sampleCov = factors.cov();

        // Posterior Normal-Wishart parameters.
        double betaPost = beta0 + numRows;
        double nuPost = wishartNu0 + numRows;
        DenseVector muPost = mu0.scale(beta0).add(sampleMean.scale(numRows)).scale(1.0d / betaPost);
        DenseVector meanDiff = mu0.minus(sampleMean);
        DenseMatrix scalePost = wishartScaleInv0.add(sampleCov.scale(numRows))
                .add(meanDiff.outer(meanDiff).scale((beta0 * numRows) / betaPost))
                .inv();

        // Symmetrize to remove round-off asymmetry before sampling the precision matrix.
        DenseMatrix precision = Randoms.wishart(scalePost.add(scalePost.transpose()).scale(0.5d), nuPost);
        if (precision != null) {
            hyperParameters.variance = precision;
        }

        // The mean is drawn with covariance (betaPost * precision)^-1.
        DenseMatrix cholesky = hyperParameters.variance.scale(betaPost).inv().cholesky();
        if (cholesky != null) {
            DenseVector noise = new DenseVector(numColumns);
            for (int c = 0; c < numColumns; c++) {
                noise.set(c, Randoms.gaussian(0.0d, 1.0d));
            }
            hyperParameters.mu = cholesky.transpose().mult(noise).add(muPost);
        }
        return hyperParameters;
    }

    /**
     * Samples one factor row from its conditional Gaussian posterior.
     *
     * @param denseMatrix     factor matrix of the other side (items when sampling a user row)
     * @param sparseVector    ratings of this user/item, indexed by the other side's ids
     * @param hyperParameters current prior mean and precision for this side
     * @return a new vector of length {@code numFactors}. If the posterior covariance has no
     *         Cholesky factor, the posterior mean is returned rather than a zero vector.
     * @throws LibrecException if a matrix operation fails
     */
    protected DenseVector updateParameters(DenseMatrix denseMatrix, SparseVector sparseVector,
            HyperParameters hyperParameters) throws LibrecException {
        int count = sparseVector.getCount();
        DenseMatrix rated = new DenseMatrix(count, numFactors);
        DenseVector residuals = new DenseVector(count);
        int k = 0;
        for (int idx : sparseVector.getIndex()) {
            residuals.set(k, sparseVector.get(idx) - globalMean);
            rated.setRow(k, denseMatrix.row(idx));
            k++;
        }

        DenseMatrix ratedT = rated.transpose();
        DenseMatrix covariance = hyperParameters.variance
                .add(ratedT.mult(rated).scale(ratingSigma))
                .inv();
        DenseVector rhs = ratedT.mult(residuals).scale(ratingSigma);
        rhs.addEqual(hyperParameters.variance.mult(hyperParameters.mu));
        DenseVector mean = covariance.mult(rhs);

        DenseMatrix cholesky = covariance.cholesky();
        if (cholesky == null) {
            // Previously an all-zero row was returned here, wiping this row's factors.
            return mean;
        }
        DenseVector noise = new DenseVector(numFactors);
        for (int f = 0; f < numFactors; f++) {
            noise.set(f, Randoms.gaussian(0.0d, 1.0d));
        }
        return cholesky.transpose().mult(noise).add(mean);
    }

    /**
     * Returns the averaged sampled rating for a test entry.
     *
     * @param userIdx inner user id
     * @param itemIdx inner item id
     * @return the stored prediction. Only meaningful for entries of {@code testMatrix}.
     */
    @Override
    public double predict(int userIdx, int itemIdx) {
        return predictMatrix.get(userIdx, itemIdx);
    }
}
