package net.librec.recommender;

import com.google.common.collect.HashBasedTable;
import java.util.Iterator;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SparseTensor;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.TensorEntry;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.item.RecommendedItemList;
import net.librec.recommender.item.RecommendedList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/* loaded from: input_file:net/librec/recommender/FactorizationMachineRecommender.class */
/**
 * Base class for Factorization Machine (FM) recommenders.
 *
 * <p>The model scores a sparse feature vector {@code x} as
 * {@code w0 + sum_i W[i]*x[i] + 1/2 * sum_f ((sum_i V[i][f]*x[i])^2 - sum_i V[i][f]^2 * x[i]^2)},
 * where {@code f} ranges over all {@link #k} latent factors. This class loads the
 * configuration and tensors, allocates the parameters and implements prediction;
 * training is left to subclasses.
 */
public abstract class FactorizationMachineRecommender extends AbstractRecommender {
    protected final Log LOG = LogFactory.getLog(getClass());

    /** Training, test and validation data, each cast from the data model's data sets. */
    protected SparseTensor trainTensor;
    protected SparseTensor testTensor;
    protected SparseTensor validTensor;

    /** Global bias term. */
    protected double w0;
    /** Total feature count: the sum of the sizes of all training-tensor dimensions. */
    protected int p;
    /** Number of latent factors (same value as {@link #numFactors}). */
    protected int k;
    /** Value of {@code trainTensor.size()}. */
    protected int n;
    /** First-order weights, one per feature (length {@link #p}). */
    protected DenseVector W;
    /** Second-order factor matrix, {@link #p} rows by {@link #k} columns. */
    protected DenseMatrix V;
    /** Not touched by this class; available to subclasses. */
    protected DenseMatrix Q;

    /** Regularization coefficients read from {@code rec.fm.regw0}, {@code rec.fm.regW}, {@code rec.fm.regF}. */
    protected float regW0;
    protected float regW;
    protected float regF;

    protected int numFactors;
    protected int numIterations;

    /**
     * Reads the configuration, fetches the tensors and mappings from the data model and
     * allocates the FM parameters.
     *
     * @throws LibrecException declared by the overridden method
     */
    @Override // net.librec.recommender.AbstractRecommender
    public void setup() throws LibrecException {
        this.conf = this.context.getConf();
        this.isRanking = this.conf.getBoolean("rec.recommender.isranking");
        if (this.isRanking) {
            this.topN = this.conf.getInt("rec.recommender.ranking.topn", 5).intValue();
        }
        this.earlyStop = this.conf.getBoolean("rec.recommender.earlyStop");
        this.numIterations = this.conf.getInt("rec.iterator.maximum").intValue();

        this.trainTensor = (SparseTensor) getDataModel().getTrainDataSet();
        this.testTensor = (SparseTensor) getDataModel().getTestDataSet();
        this.validTensor = (SparseTensor) getDataModel().getValidDataSet();
        this.userMappingData = getDataModel().getUserMappingData();
        this.itemMappingData = getDataModel().getItemMappingData();
        this.numUsers = this.userMappingData.size();
        this.numItems = this.itemMappingData.size();

        this.globalMean = this.trainTensor.mean();
        this.maxRate = this.conf.getDouble("rec.recommender.maxrate", Double.valueOf(12.0d)).doubleValue();
        this.minRate = this.conf.getDouble("rec.recommender.minrate", Double.valueOf(0.0d)).doubleValue();

        // Every tensor dimension contributes one block of features, so the
        // feature-space size is the sum of all dimension sizes.
        for (int dim = 0; dim < this.trainTensor.numDimensions; dim++) {
            this.p += this.trainTensor.dimensions[dim];
        }
        this.n = this.trainTensor.size();

        int factorCount = this.conf.getInt("rec.factor.number").intValue();
        this.k = factorCount;
        this.numFactors = factorCount;

        this.w0 = 0.0d;
        this.W = new DenseVector(this.p);
        this.W.init(0.0d);
        this.V = new DenseMatrix(this.p, this.k);
        // NOTE(review): presumably random initialisation with mean 0.0 and
        // spread 0.1 -- confirm against DenseMatrix.init(double, double).
        this.V.init(0.0d, 0.1d);

        this.regW0 = this.conf.getFloat("rec.fm.regw0", Float.valueOf(0.01f)).floatValue();
        this.regW = this.conf.getFloat("rec.fm.regW", Float.valueOf(0.01f)).floatValue();
        this.regF = this.conf.getFloat("rec.fm.regF", Float.valueOf(10.0f)).floatValue();
    }

    /**
     * Computes the unbounded FM score for one feature vector.
     *
     * @param i            user index of the entry (not used by this implementation)
     * @param i2           item index of the entry (not used by this implementation)
     * @param sparseVector feature vector of the entry, indexed in the {@link #p}-sized feature space
     * @return global bias + first-order term + pairwise interaction term
     * @throws LibrecException never thrown here; kept for overriding subclasses
     */
    public double predict(int i, int i2, SparseVector sparseVector) throws LibrecException {
        double score = this.w0;

        // First-order (linear) term.
        for (VectorEntry entry : sparseVector) {
            score += entry.get() * this.W.get(entry.index());
        }

        // Second-order term. The loop must start at factor 0: V has columns
        // 0..k-1, and starting at 1 would silently drop the first factor
        // (and the whole interaction term when k == 1).
        for (int f = 0; f < this.k; f++) {
            double sum = 0.0d;
            double sumOfSquares = 0.0d;
            for (VectorEntry entry : sparseVector) {
                double x = entry.get();
                double v = this.V.get(entry.index(), f);
                sum += v * x;
                sumOfSquares += v * v * x * x;
            }
            score += ((sum * sum) - sumOfSquares) / 2.0d;
        }
        return score;
    }

    /**
     * Computes the FM score and optionally clamps it to {@code [minRate, maxRate]}.
     *
     * @param i            user index of the entry
     * @param i2           item index of the entry
     * @param sparseVector feature vector of the entry
     * @param z            when {@code true}, the score is clamped to the configured rating bounds
     * @return the (optionally bounded) prediction
     * @throws LibrecException if the underlying prediction fails
     */
    protected double predict(int i, int i2, SparseVector sparseVector, boolean z) throws LibrecException {
        double predict = predict(i, i2, sparseVector);
        if (z) {
            if (predict > this.maxRate) {
                predict = this.maxRate;
            }
            if (predict < this.minRate) {
                predict = this.minRate;
            }
        }
        return predict;
    }

    /**
     * Predicts a bounded rating for every entry of the test tensor and collects the
     * results, keeping only the first prediction for each (user, item) pair.
     *
     * @return the populated recommended list
     * @throws LibrecException if a prediction fails
     */
    @Override // net.librec.recommender.AbstractRecommender
    protected RecommendedList recommendRating() throws LibrecException {
        this.testMatrix = this.testTensor.rateMatrix();
        this.recommendedList = new RecommendedItemList(this.numUsers - 1, this.numUsers);

        // Remembers the (user, item) pairs already emitted so that each pair
        // appears in the result only once.
        HashBasedTable<Integer, Integer, Double> emitted = HashBasedTable.create();

        int userDimension = this.testTensor.getUserDimension();
        int itemDimension = this.testTensor.getItemDimension();
        for (TensorEntry tensorEntry : this.testTensor) {
            int[] keys = tensorEntry.keys();
            SparseVector featureVector = tenserKeysToFeatureVector(keys);
            double predict = predict(keys[userDimension], keys[itemDimension], featureVector, true);
            if (Double.isNaN(predict)) {
                // Fall back to the training mean when the model produced NaN.
                predict = this.globalMean;
            }

            int[] userItemIndex = getUserItemIndex(featureVector);
            int userIdx = userItemIndex[0];
            int itemIdx = userItemIndex[1];
            if (!emitted.contains(userIdx, itemIdx)) {
                emitted.put(userIdx, itemIdx, predict);
                this.recommendedList.addUserItemIdx(userIdx, itemIdx, predict);
            }
        }
        return this.recommendedList;
    }

    /**
     * Recovers the user and item indices from the first two entries of a feature vector.
     *
     * <p>NOTE(review): this assumes the user is tensor dimension 0 and the item is
     * dimension 1 with an offset of exactly {@code numUsers} -- confirm against the
     * tensor layout produced by the data model.
     */
    private int[] getUserItemIndex(SparseVector sparseVector) {
        int[] index = sparseVector.getIndex();
        return new int[]{index[0], index[1] - this.numUsers};
    }

    /**
     * Converts the per-dimension keys of a tensor entry into a one-hot style feature
     * vector: each dimension sets value 1.0 at position (offset of that dimension + key),
     * where the offset is the summed size of all preceding training-tensor dimensions.
     *
     * @param iArr keys of a tensor entry, one per dimension
     * @return feature vector of capacity {@link #p} with one active feature per dimension
     */
    public SparseVector tenserKeysToFeatureVector(int[] iArr) {
        int capacity = this.p;
        int[] index = new int[iArr.length];
        double[] data = new double[iArr.length];
        int offset = 0;
        for (int dim = 0; dim < iArr.length; dim++) {
            data[dim] = 1.0d;
            index[dim] = offset + iArr[dim];
            offset += this.trainTensor.dimensions[dim];
        }
        return new SparseVector(capacity, index, data);
    }
}
