package net.librec.recommender.cf.ranking;

import com.google.common.cache.LoadingCache;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Group Bayesian Personalized Ranking (GBPR) recommender.
 *
 * <p>Each training sample draws a user {@code u}, an item {@code i} from {@code u}'s training
 * row, a group {@code G} of users taken from item {@code i}'s training column (always containing
 * {@code u}), and a negative item {@code j} absent from {@code u}'s row. The scores are
 *
 * <pre>
 *   r(G, u, i) = rho * mean_{w in G}(U_w . V_i) + (1 - rho) * U_u . V_i + b_i
 *   r(u, j)    = U_u . V_j + b_j
 * </pre>
 *
 * and each sample contributes {@code -ln sigmoid(r(G,u,i) - r(u,j))} to the loss, with L2
 * regularisation on factors and item biases.
 *
 * <p>Configuration keys read in {@link #setup()}: {@code rec.gpbr.rho} (default 1.5),
 * {@code rec.gpbr.gsize} (default 2), {@code rec.bias.regularization} (default: the current
 * {@link #regBias}, i.e. 0 unless preset) and {@code guava.cache.spec}.
 */
@ModelData({"isRanking", "gbpr", "userFactors", "itemFactors", "trainMatrix"})
public class GBPRRecommender extends MatrixFactorizationRecommender {
    /** Weight of the group preference term versus the individual preference term. */
    private float rho;
    /** Target group size; items with at most this many raters use all of them as the group. */
    private int gLen;
    /** L2 regularisation applied to item biases. */
    protected double regBias;
    /** Per-item bias {@code b_i}. */
    private DenseVector itemBiases;
    /** Cache: user index -> column indices of that user's training row. */
    protected LoadingCache<Integer, List<Integer>> userItemsCache;
    /** Cache: item index -> row indices of that item's training column. */
    protected LoadingCache<Integer, List<Integer>> itemUsersCache;
    /** Guava cache spec shared by both index caches. */
    protected static String cacheSpec;

    /**
     * Initialises item biases, reads the GBPR hyper-parameters and builds the index caches.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.itemBiases = new DenseVector(this.numItems);
        this.itemBiases.init();
        // NOTE: the keys are spelled "gpbr"; kept verbatim so existing configs keep working.
        this.rho = this.conf.getFloat("rec.gpbr.rho", Float.valueOf(1.5f)).floatValue();
        this.gLen = this.conf.getInt("rec.gpbr.gsize", 2).intValue();
        // regBias drives the item-bias regularisation in trainModel but was never configured,
        // which silently disabled it. Defaulting to the current value preserves old behaviour
        // when the key is absent.
        this.regBias = this.conf.getFloat("rec.bias.regularization",
                Float.valueOf((float) this.regBias)).floatValue();
        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);
        this.itemUsersCache = this.trainMatrix.columnRowsCache(cacheSpec);
    }

    /**
     * Runs {@code numIterations} epochs of {@code numUsers * 100} sampled (u, G, i, j) tuples.
     *
     * <p>Item biases are updated immediately per sample. Factor updates are accumulated into
     * delta matrices and applied once at the end of each iteration, so every sample within an
     * iteration reads the same {@code userFactors} / {@code itemFactors}.
     *
     * @throws LibrecException propagated from the group prediction
     * @throws IllegalStateException if an index cache fails to load an entry
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; iter++) {
            this.loss = 0.0d;
            DenseMatrix userDeltas = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix itemDeltas = new DenseMatrix(this.numItems, this.numFactors);
            int numSamples = this.numUsers * 100;
            for (int sample = 0; sample < numSamples; sample++) {
                // Sample a user who has at least one positive item but not every item; a user
                // with every item would make the negative-sampling loop below spin forever.
                int userIdx;
                List<Integer> userItems;
                do {
                    userIdx = Randoms.uniform(this.trainMatrix.numRows());
                    userItems = cachedIndices(this.userItemsCache, userIdx);
                } while (userItems.isEmpty() || userItems.size() >= this.numItems);

                int posItemIdx = ((Integer) Randoms.random(userItems)).intValue();

                // Build the group: u plus other users who interacted with the positive item.
                List<Integer> itemUsers = cachedIndices(this.itemUsersCache, posItemIdx);
                Set<Integer> group = new HashSet<>();
                group.add(userIdx);
                if (itemUsers.size() <= this.gLen) {
                    group.addAll(itemUsers);
                } else {
                    // itemUsers holds more than gLen distinct users, so this terminates.
                    while (group.size() < this.gLen) {
                        group.add((Integer) Randoms.random(itemUsers));
                    }
                }
                double groupScore = predict(userIdx, posItemIdx, group);

                int negItemIdx;
                do {
                    negItemIdx = Randoms.uniform(this.numItems);
                } while (userItems.contains(negItemIdx));

                double margin = groupScore - predict(userIdx, negItemIdx);
                this.loss += negLogSigmoid(margin);
                double deriv = Maths.logistic(-margin);

                this.itemBiases.add(posItemIdx,
                        this.learnRate * (deriv - this.regBias * this.itemBiases.get(posItemIdx)));
                this.itemBiases.add(negItemIdx,
                        this.learnRate * (-deriv - this.regBias * this.itemBiases.get(negItemIdx)));

                double invGroupSize = 1.0d / group.size();
                // Sum of U_w over the group, needed for the positive-item gradient.
                double[] groupUserSum = new double[this.numFactors];
                for (int groupUser : group) {
                    // Only u itself receives the individual and negative-item terms.
                    double isSelf = groupUser == userIdx ? 1.0d : 0.0d;
                    for (int f = 0; f < this.numFactors; f++) {
                        double uwf = this.userFactors.get(groupUser, f);
                        double vif = this.itemFactors.get(posItemIdx, f);
                        double vjf = this.itemFactors.get(negItemIdx, f);
                        double grad = this.rho * invGroupSize * vif
                                + (1.0f - this.rho) * isSelf * vif
                                - isSelf * vjf;
                        userDeltas.add(groupUser, f,
                                this.learnRate * (deriv * grad - this.regUser * uwf));
                        groupUserSum[f] += uwf;
                    }
                }
                for (int f = 0; f < this.numFactors; f++) {
                    double uuf = this.userFactors.get(userIdx, f);
                    double vif = this.itemFactors.get(posItemIdx, f);
                    double vjf = this.itemFactors.get(negItemIdx, f);
                    itemDeltas.add(posItemIdx, f, this.learnRate
                            * (deriv * (this.rho * invGroupSize * groupUserSum[f]
                                    + (1.0f - this.rho) * uuf)
                                    - this.regItem * vif));
                    itemDeltas.add(negItemIdx, f,
                            this.learnRate * (deriv * -uuf - this.regItem * vjf));
                }
            }
            this.userFactors.addEqual(userDeltas);
            this.itemFactors.addEqual(itemDeltas);
            if (isConverged(iter) && this.earlyStop) {
                return;
            }
            updateLRate(iter);
        }
    }

    /**
     * Loads an index list from a cache, surfacing loader failures instead of returning null.
     *
     * @param cache the user->items or item->users cache
     * @param key the row/column index to look up
     * @return the cached index list
     * @throws IllegalStateException wrapping the loader's failure
     */
    private static List<Integer> cachedIndices(LoadingCache<Integer, List<Integer>> cache, int key) {
        try {
            return cache.get(key);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Failed to load cached index list for key " + key, e);
        }
    }

    /**
     * Computes {@code -ln(1 / (1 + exp(-x)))} without overflowing for large |x|.
     *
     * @param x the preference margin
     * @return the non-negative log-loss for that margin
     */
    private static double negLogSigmoid(double x) {
        return x >= 0 ? Math.log1p(Math.exp(-x)) : -x + Math.log1p(Math.exp(x));
    }

    /**
     * Group-augmented score {@code rho * (mean_{w in G} U_w . V_i + b_i) + (1 - rho) * r(u, i)}.
     *
     * @param i the user index u
     * @param i2 the item index i
     * @param set the group G; must be non-empty (an empty group yields NaN)
     * @return the blended group/individual preference score
     * @throws LibrecException declared for subclasses
     */
    protected double predict(int i, int i2, Set<Integer> set) throws LibrecException {
        double individual = predict(i, i2);
        double groupSum = 0.0d;
        for (int member : set) {
            groupSum += DenseMatrix.rowMult(this.userFactors, member, this.itemFactors, i2);
        }
        return this.rho * (groupSum / set.size() + this.itemBiases.get(i2))
                + (1.0f - this.rho) * individual;
    }

    /**
     * Individual score {@code b_i + U_u . V_i}.
     *
     * @param i the user index
     * @param i2 the item index
     * @return the predicted preference of user {@code i} for item {@code i2}
     */
    @Override
    public double predict(int i, int i2) {
        return this.itemBiases.get(i2) + DenseMatrix.rowMult(this.userFactors, i, this.itemFactors, i2);
    }
}
