package net.librec.recommender.cf.rating;

import com.google.common.collect.HashBasedTable;
import java.util.Iterator;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SparseMatrix;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.TensorEntry;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.FactorizationMachineRecommender;

/**
 * Factorization machine recommender fitted with alternating least squares (FM-ALS).
 *
 * <p>Each training tensor entry is encoded as a binary feature row. For every tensor dimension, the
 * entry's key in that dimension switches on one feature column. The columns of successive
 * dimensions are laid out one after another. Each training iteration solves, in closed form, for:
 * <ul>
 *   <li>the global bias,</li>
 *   <li>each linear weight,</li>
 *   <li>each pairwise-factor entry.</li>
 * </ul>
 * These are solved one at a time. The per-entry residuals and the cached factor sums {@code Q}
 * are updated incrementally as each parameter changes.
 */
@ModelData({"isRanking", "fmals", "W", "V", "W0", "k"})
public class FMALSRecommender extends FactorizationMachineRecommender {
    /** Cached sums {@code Q[h][f] = sum_l V[l][f] * x[h][l]} for training entry h and factor f. */
    private DenseMatrix Q;
    /** Binary design matrix: one row per training entry, one column per feature. */
    private SparseMatrix trainFeatureMatrix;
    /**
     * For each feature column, the ascending row indices at which {@link #trainFeatureMatrix} has
     * a stored value.
     *
     * <p>All other cells are zero and contribute nothing to any ALS update.
     */
    private int[][] featureRows;

    /**
     * Builds the binary training design matrix and its per-column row index, and allocates
     * {@code Q}.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.Q = new DenseMatrix(this.n, this.k);

        HashBasedTable<Integer, Integer, Double> cells = HashBasedTable.create();
        int[] columnSizes = new int[this.p];
        for (int row = 0; row < this.n; row++) {
            int[] keys = this.trainTensor.keys(row);
            // Feature columns of dimension d start after the columns of all earlier dimensions.
            int offset = 0;
            for (int dim = 0; dim < keys.length; dim++) {
                int column = offset + keys[dim];
                offset += this.trainTensor.dimensions[dim];
                // Count each distinct (row, column) cell once, exactly as the table stores it.
                if (cells.put(row, column, 1.0d) == null) {
                    columnSizes[column]++;
                }
            }
        }
        this.trainFeatureMatrix = new SparseMatrix(this.n, this.p, cells);

        this.featureRows = new int[this.p][];
        for (int column = 0; column < this.p; column++) {
            this.featureRows[column] = new int[columnSizes[column]];
        }
        int[] filled = new int[this.p];
        // Rows are visited in ascending order, so every column's row list ends up sorted.
        for (int row = 0; row < this.n; row++) {
            for (Integer column : cells.row(row).keySet()) {
                this.featureRows[column][filled[column]++] = row;
            }
        }
    }

    /**
     * Runs up to {@code numIterations} ALS sweeps over the bias, linear weights and factors.
     *
     * <p>Training stops early when {@code isConverged(iter)} reports convergence and
     * {@code earlyStop} is set. {@code isConverged} is always evaluated first, as before.
     *
     * @throws LibrecException if computing the initial predictions fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        DenseVector errors = new DenseVector(this.n);
        initResidualsAndFactorSums(errors);

        for (int iter = 0; iter < this.numIterations; iter++) {
            this.loss = 0.0d;
            logResidualSum("original:", errors);
            updateGlobalBias(errors);
            logResidualSum("after 0-way:", errors);
            updateLinearWeights(errors);
            logResidualSum("after 1-way:", errors);
            updatePairwiseFactors(errors);
            logResidualSum("after 2-way:", errors);
            if (isConverged(iter) && this.earlyStop) {
                return;
            }
        }
    }

    /**
     * Sets {@code errors[h] = rating - prediction} for every training entry h and fills {@code Q}.
     *
     * <p>Entries are indexed in the order of {@code trainTensor.iterator()}. This code assumes that
     * order matches {@code trainTensor.keys(h)} as used in {@link #setup()}. NOTE(review): confirm
     * this against the tensor implementation.
     */
    private void initResidualsAndFactorSums(DenseVector errors) throws LibrecException {
        int userDimension = this.trainTensor.getUserDimension();
        int itemDimension = this.trainTensor.getItemDimension();
        int h = 0;
        Iterator<TensorEntry> entries = this.trainTensor.iterator();
        while (entries.hasNext()) {
            TensorEntry entry = entries.next();
            int[] keys = entry.keys();
            SparseVector features = tenserKeysToFeatureVector(keys);
            errors.set(h, entry.get() - predict(keys[userDimension], keys[itemDimension], features));
            for (int f = 0; f < this.k; f++) {
                double sum = 0.0d;
                Iterator<VectorEntry> active = features.iterator();
                while (active.hasNext()) {
                    VectorEntry feature = active.next();
                    sum += this.V.get(feature.index(), f) * feature.get();
                }
                this.Q.set(h, f, sum);
            }
            h++;
        }
    }

    /** Re-solves the global bias {@code w0}; its "feature" is the constant 1 for every entry. */
    private void updateGlobalBias(DenseVector errors) {
        double numerator = 0.0d;
        double denominator = 0.0d;
        for (int h = 0; h < this.n; h++) {
            numerator += this.w0 + errors.get(h);
            denominator += 1.0d;
        }
        double newW0 = solveOrKeep(numerator, denominator + this.regW0, this.w0);
        for (int h = 0; h < this.n; h++) {
            double oldError = errors.get(h);
            errors.set(h, oldError + (this.w0 - newW0));
            // The data term uses the residual from before this sweep's bias update.
            this.loss += oldError * oldError;
        }
        this.w0 = newW0;
        this.loss += this.regW0 * this.w0 * this.w0;
    }

    /** Re-solves each linear weight {@code W[l]} in turn, visiting only rows where feature l is set. */
    private void updateLinearWeights(DenseVector errors) {
        for (int l = 0; l < this.p; l++) {
            double oldWeight = this.W.get(l);
            int[] rows = this.featureRows[l];
            double numerator = 0.0d;
            double denominator = 0.0d;
            for (int h : rows) {
                double x = this.trainFeatureMatrix.get(h, l);
                numerator += oldWeight * x * x + x * errors.get(h);
                denominator += x * x;
            }
            double newWeight = solveOrKeep(numerator, denominator + this.regW, oldWeight);
            for (int h : rows) {
                errors.set(h, errors.get(h) + (oldWeight - newWeight) * this.trainFeatureMatrix.get(h, l));
            }
            this.W.set(l, newWeight);
            this.loss += this.regW * oldWeight * oldWeight;
        }
    }

    /**
     * Re-solves each factor entry {@code V[l][f]} in turn.
     *
     * <p>Each solve keeps the residuals and the cached sums {@code Q[., f]} consistent. Only rows
     * where feature l is set are visited.
     */
    private void updatePairwiseFactors(DenseVector errors) {
        for (int f = 0; f < this.k; f++) {
            for (int l = 0; l < this.p; l++) {
                double oldFactor = this.V.get(l, f);
                int[] rows = this.featureRows[l];
                double numerator = 0.0d;
                double denominator = 0.0d;
                for (int h : rows) {
                    double x = this.trainFeatureMatrix.get(h, l);
                    // Derivative of the prediction for entry h with respect to V[l][f].
                    double hTheta = x * (this.Q.get(h, f) - oldFactor * x);
                    numerator += oldFactor * hTheta * hTheta + hTheta * errors.get(h);
                    denominator += hTheta * hTheta;
                }
                double newFactor = solveOrKeep(numerator, denominator + this.regF, oldFactor);
                for (int h : rows) {
                    double x = this.trainFeatureMatrix.get(h, l);
                    double oldQ = this.Q.get(h, f);
                    double newQ = oldQ + (newFactor - oldFactor) * x;
                    double oldH = x * (oldQ - oldFactor * x);
                    double newH = x * (newQ - newFactor * x);
                    errors.set(h, errors.get(h) + oldFactor * oldH - newFactor * newH);
                    this.Q.set(h, f, newQ);
                }
                this.V.set(l, f, newFactor);
                this.loss += this.regF * oldFactor * oldFactor;
            }
        }
    }

    /**
     * Returns {@code numerator / denominator}, or {@code oldValue} when the denominator is zero.
     *
     * <p>A zero denominator happens, for example, when a feature has no training rows and its
     * regularizer is 0. Dividing would yield NaN and corrupt the model.
     */
    private static double solveOrKeep(double numerator, double denominator, double oldValue) {
        return denominator == 0.0d ? oldValue : numerator / denominator;
    }

    /** Logs the residual sum after a training stage, only when debug logging is enabled. */
    private void logResidualSum(String stage, DenseVector errors) {
        if (this.LOG.isDebugEnabled()) {
            this.LOG.debug(stage + errors.sum());
        }
    }

    /**
     * Not supported by this model: always returns 0.
     *
     * <p>Training uses the inherited {@code predict(user, item, featureVector)} overload instead.
     */
    @Override
    @Deprecated
    protected double predict(int i, int i2) throws LibrecException {
        return 0.0d;
    }
}
