package net.librec.recommender.context.rating;

import com.google.common.cache.LoadingCache;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.SocialRecommender;

/**
 * TrustSVD rating recommender.
 *
 * <p>A rating is predicted as
 * {@code mu + b_u + b_i + p_u.q_i + |I_u|^-1/2 * sum_{j in I_u} y_j.q_i
 * + |T_u|^-1/2 * sum_{v in T_u} w_v.q_i}, where:
 * <ul>
 *   <li>{@code I_u} are the items user {@code u} has rated in {@code trainMatrix};</li>
 *   <li>{@code T_u} are the users in row {@code u} of {@code socialMatrix};</li>
 *   <li>{@code y_j} are rows of {@code impItemFactors};</li>
 *   <li>{@code w_v} are rows of {@code trusteeFactors}.</li>
 * </ul>
 * Training also fits {@code p_u.w_v} to every non-zero social-matrix entry {@code (u, v)}.
 *
 * <p>NOTE(review): {@code trusteeFactors} is read by {@link #predict(int, int)} but is not listed
 * in {@code @ModelData}. Confirm whether it must be included with the other model data.
 */
@ModelData({"isRating", "trustsvd", "userFactors", "itemFactors", "impItemFactors", "userBiases", "itemBiases", "socialMatrix", "trainMatrix"})
public class TrustSVDRecommender extends SocialRecommender {
    /** Implicit-feedback factors {@code y_j}, one row per item. */
    private DenseMatrix impItemFactors;
    /** Trustee factors {@code w_v}, one row per user. */
    private DenseMatrix trusteeFactors;
    /** Per-user weight {@code 1/sqrt(column count of socialMatrix)}, or 1 when the column is empty. */
    private DenseVector trusteeWeights;
    /** Per-user weight {@code 1/sqrt(row count of socialMatrix)}, or 1 when the row is empty. */
    private DenseVector trusterWeights;
    /** Per-item weight {@code 1/sqrt(column count of trainMatrix)}, or 1 when the column is empty. */
    private DenseVector impItemWeights;
    private DenseVector userBiases;
    private DenseVector itemBiases;
    /** Regularization strength for the bias terms ("rec.bias.regularization"). */
    protected double regBias;
    /** Per-user list of column indices of {@code trainMatrix} (the items the user rated). */
    protected LoadingCache<Integer, List<Integer>> userItemsCache;
    /** Per-user list of column indices of {@code socialMatrix} (the users that user trusts). */
    protected LoadingCache<Integer, List<Integer>> userTrusteeCache;
    /**
     * Guava cache specification used for both caches ("guava.cache.spec").
     *
     * <p>NOTE(review): this field is static but is overwritten by every instance's {@link #setup()}.
     * It is kept static only for compatibility with existing subclasses.
     */
    protected static String cacheSpec;

    /**
     * Reads the configuration and initializes biases, factor matrices, regularization weights
     * and the per-user caches.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        regBias = conf.getDouble("rec.bias.regularization", Double.valueOf(0.01d)).doubleValue();
        cacheSpec = conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");

        // Random initialization order is kept stable so results stay reproducible.
        userBiases = new DenseVector(numUsers);
        itemBiases = new DenseVector(numItems);
        userBiases.init(initMean, initStd);
        itemBiases.init(initMean, initStd);
        trusteeFactors = new DenseMatrix(numUsers, numFactors);
        impItemFactors = new DenseMatrix(numItems, numFactors);
        trusteeFactors.init(initMean, initStd);
        impItemFactors.init(initMean, initStd);

        trusteeWeights = new DenseVector(numUsers);
        trusterWeights = new DenseVector(numUsers);
        impItemWeights = new DenseVector(numItems);
        for (int userIdx = 0; userIdx < numUsers; userIdx++) {
            trusteeWeights.set(userIdx, inverseSqrtOrOne(socialMatrix.columnSize(userIdx)));
            trusterWeights.set(userIdx, inverseSqrtOrOne(socialMatrix.rowSize(userIdx)));
        }
        for (int itemIdx = 0; itemIdx < numItems; itemIdx++) {
            impItemWeights.set(itemIdx, inverseSqrtOrOne(trainMatrix.columnSize(itemIdx)));
        }

        userItemsCache = trainMatrix.rowColumnsCache(cacheSpec);
        userTrusteeCache = socialMatrix.rowColumnsCache(cacheSpec);
    }

    /**
     * Runs {@code numIterations} epochs of training.
     *
     * <p>Biases, item factors and implicit item factors are updated per rating. The gradients of
     * {@code userFactors} and {@code trusteeFactors} are accumulated over the whole epoch, including
     * the social term, and applied once at the end of it.
     *
     * @throws LibrecException if a per-user cache lookup fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= numIterations; iter++) {
            loss = 0.0d;
            DenseMatrix userFactorsGradient = new DenseMatrix(numUsers, numFactors);
            DenseMatrix trusteeFactorsGradient = new DenseMatrix(numUsers, numFactors);

            Iterator<MatrixEntry> ratingIter = trainMatrix.iterator();
            while (ratingIter.hasNext()) {
                updateFromRating(ratingIter.next(), userFactorsGradient, trusteeFactorsGradient);
            }
            Iterator<MatrixEntry> socialIter = socialMatrix.iterator();
            while (socialIter.hasNext()) {
                accumulateSocialGradient(socialIter.next(), userFactorsGradient, trusteeFactorsGradient);
            }

            userFactors.addEqual(userFactorsGradient.scale(-learnRate));
            trusteeFactors.addEqual(trusteeFactorsGradient.scale(-learnRate));
            loss *= 0.5d;
            if (isConverged(iter) && earlyStop) {
                return;
            }
            updateLRate(iter);
        }
    }

    /** Performs the per-rating SGD step and accumulates user/trustee factor gradients. */
    private void updateFromRating(MatrixEntry entry, DenseMatrix userFactorsGradient,
                                  DenseMatrix trusteeFactorsGradient) throws LibrecException {
        int userIdx = entry.row();
        int itemIdx = entry.column();
        double rating = entry.get();
        double userBias = userBiases.get(userIdx);
        double itemBias = itemBiases.get(itemIdx);

        List<Integer> ratedItems = loadUserList(userItemsCache, userIdx, "rated items");
        List<Integer> trustees = loadUserList(userTrusteeCache, userIdx, "trustees");

        double error = predictWithContext(userIdx, itemIdx, ratedItems, trustees) - rating;
        loss += error * error;

        double sqrtNumRated = Math.sqrt(ratedItems.size());
        double sqrtNumTrustees = Math.sqrt(trustees.size());
        // NOTE(review): presumably ratedItems is never empty here, since this entry comes from
        // trainMatrix. An empty list would make this weight infinite.
        double userWeight = 1.0d / sqrtNumRated;
        double itemWeight = impItemWeights.get(itemIdx);

        userBiases.add(userIdx, (-learnRate) * (error + (regBias * userWeight * userBias)));
        itemBiases.add(itemIdx, (-learnRate) * (error + (regBias * itemWeight * itemBias)));
        loss += (regBias * userWeight * userBias * userBias) + (regBias * itemWeight * itemBias * itemBias);

        // Normalized sums are taken before any implicit factor is modified in the loop below.
        double[] ratedItemsTerm = normalizedFactorSum(impItemFactors, ratedItems, sqrtNumRated);
        double[] trusteesTerm = normalizedFactorSum(trusteeFactors, trustees, sqrtNumTrustees);

        for (int f = 0; f < numFactors; f++) {
            double userFactor = userFactors.get(userIdx, f);
            double itemFactor = itemFactors.get(itemIdx, f);
            double userGrad = (error * itemFactor) + (regUser * userWeight * userFactor);
            double itemGrad = (error * (userFactor + ratedItemsTerm[f] + trusteesTerm[f]))
                    + (regItem * itemWeight * itemFactor);
            userFactorsGradient.add(userIdx, f, userGrad);
            itemFactors.add(itemIdx, f, (-learnRate) * itemGrad);
            loss += (regUser * userWeight * userFactor * userFactor)
                    + (regItem * itemWeight * itemFactor * itemFactor);

            for (int ratedIdx : ratedItems) {
                double implicitFactor = impItemFactors.get(ratedIdx, f);
                double implicitWeight = impItemWeights.get(ratedIdx);
                impItemFactors.add(ratedIdx, f, (-learnRate)
                        * (((error * itemFactor) / sqrtNumRated) + (regItem * implicitWeight * implicitFactor)));
                loss += regItem * implicitWeight * implicitFactor * implicitFactor;
            }
            for (int trusteeIdx : trustees) {
                double trusteeFactor = trusteeFactors.get(trusteeIdx, f);
                double trusteeWeight = trusteeWeights.get(trusteeIdx);
                trusteeFactorsGradient.add(trusteeIdx, f,
                        ((error * itemFactor) / sqrtNumTrustees) + (regUser * trusteeWeight * trusteeFactor));
                loss += regUser * trusteeWeight * trusteeFactor * trusteeFactor;
            }
        }
    }

    /** Accumulates the social-reconstruction gradient for one social-matrix entry (zeros skipped). */
    private void accumulateSocialGradient(MatrixEntry entry, DenseMatrix userFactorsGradient,
                                          DenseMatrix trusteeFactorsGradient) {
        int trusterIdx = entry.row();
        int trusteeIdx = entry.column();
        double trustValue = entry.get();
        if (trustValue == 0.0d) {
            return;
        }
        double socialError = DenseMatrix.rowMult(userFactors, trusterIdx, trusteeFactors, trusteeIdx) - trustValue;
        loss += regSocial * socialError * socialError;
        double scaledError = regSocial * socialError;
        double trusterWeight = trusterWeights.get(trusterIdx);
        for (int f = 0; f < numFactors; f++) {
            double userFactor = userFactors.get(trusterIdx, f);
            userFactorsGradient.add(trusterIdx, f,
                    (scaledError * trusteeFactors.get(trusteeIdx, f)) + (regSocial * trusterWeight * userFactor));
            trusteeFactorsGradient.add(trusteeIdx, f, scaledError * userFactor);
            loss += regSocial * trusterWeight * userFactor * userFactor;
        }
    }

    /**
     * Returns the per-factor sum of {@code factors} over {@code rows}, divided by {@code norm}
     * when {@code norm > 0}.
     */
    private double[] normalizedFactorSum(DenseMatrix factors, List<Integer> rows, double norm) {
        double[] sum = new double[numFactors];
        for (int row : rows) {
            for (int f = 0; f < numFactors; f++) {
                sum[f] += factors.get(row, f);
            }
        }
        if (norm > 0.0d) {
            for (int f = 0; f < numFactors; f++) {
                sum[f] /= norm;
            }
        }
        return sum;
    }

    /** Computes the TrustSVD prediction given the user's already-loaded rated items and trustees. */
    private double predictWithContext(int userIdx, int itemIdx, List<Integer> ratedItems, List<Integer> trustees) {
        double prediction = globalMean + userBiases.get(userIdx) + itemBiases.get(itemIdx)
                + DenseMatrix.rowMult(userFactors, userIdx, itemFactors, itemIdx);
        if (!ratedItems.isEmpty()) {
            double implicitSum = 0.0d;
            for (int ratedIdx : ratedItems) {
                implicitSum += DenseMatrix.rowMult(impItemFactors, ratedIdx, itemFactors, itemIdx);
            }
            prediction += implicitSum / Math.sqrt(ratedItems.size());
        }
        if (!trustees.isEmpty()) {
            double trustSum = 0.0d;
            for (int trusteeIdx : trustees) {
                trustSum += DenseMatrix.rowMult(trusteeFactors, trusteeIdx, itemFactors, itemIdx);
            }
            prediction += trustSum / Math.sqrt(trustees.size());
        }
        return prediction;
    }

    /**
     * Looks up a user's index list from a cache.
     *
     * @throws LibrecException wrapping the loader failure, instead of continuing with a null list
     */
    private static List<Integer> loadUserList(LoadingCache<Integer, List<Integer>> cache, int userIdx, String what)
            throws LibrecException {
        try {
            return cache.get(userIdx);
        } catch (ExecutionException e) {
            throw new LibrecException("Failed to load " + what + " of user " + userIdx, e);
        }
    }

    /** Returns {@code 1/sqrt(count)} for a positive count, otherwise 1. */
    private static double inverseSqrtOrOne(int count) {
        return count > 0 ? 1.0d / Math.sqrt(count) : 1.0d;
    }

    /**
     * Predicts the rating of user {@code userIdx} for item {@code itemIdx}.
     *
     * @throws LibrecException if the user's rated-item or trustee list cannot be loaded
     */
    @Override
    public double predict(int userIdx, int itemIdx) throws LibrecException {
        List<Integer> ratedItems = loadUserList(userItemsCache, userIdx, "rated items");
        List<Integer> trustees = loadUserList(userTrusteeCache, userIdx, "trustees");
        return predictWithContext(userIdx, itemIdx, ratedItems, trustees);
    }

    /** Same as {@link #predict(int, int)}; the flag is ignored. */
    @Override
    public double predict(int userIdx, int itemIdx, boolean bound) throws LibrecException {
        return predict(userIdx, itemIdx);
    }
}
