package de.jungblut.math.minimize;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.math.tuple.Tuple;
import de.jungblut.partition.Boundaries;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;

/**
 * Base class for cost functions that split the training rows into contiguous mini-batches
 * and evaluate them on a fixed pool of daemon worker threads.
 *
 * <p>Subclasses implement {@link #evaluateBatch(DoubleVector, DoubleMatrix, DoubleMatrix)} for a
 * single batch; {@link #evaluateCost(DoubleVector)} averages cost and gradient over the batches
 * it evaluates.
 *
 * <p>In stochastic mode {@link #evaluateCost(DoubleVector)} advances an unsynchronized batch
 * cursor, so a single instance must not be evaluated from several threads concurrently.
 */
public abstract class AbstractMiniBatchCostFunction implements CostFunction {

    /** Worker pool used to evaluate batches; threads are daemons and the pool is never shut down. */
    private final Executor pool;
    /** Pre-built (features, outcome) matrix pairs, one per batch, ordered by starting row. */
    private final List<Tuple<DoubleMatrix, DoubleMatrix>> batches;
    /** If true, each call to evaluateCost evaluates only one batch, cycling through them. */
    private final boolean stochastic;
    /** Index of the next batch to evaluate in stochastic mode. */
    private int batchOffset;

    /** Task that evaluates one batch against a parameter vector via {@link #evaluateBatch}. */
    class CallableMiniBatch implements Callable<CostGradientTuple> {
        private final DoubleVector parameters;
        private final Tuple<DoubleMatrix, DoubleMatrix> featureOutcome;

        public CallableMiniBatch(Tuple<DoubleMatrix, DoubleMatrix> featureOutcome, DoubleVector parameters) {
            this.featureOutcome = featureOutcome;
            this.parameters = parameters;
        }

        @Override
        public CostGradientTuple call() throws Exception {
            return evaluateBatch(parameters, featureOutcome.getFirst(), featureOutcome.getSecond());
        }
    }

    /**
     * Creates a non-stochastic mini-batch cost function.
     *
     * @param features training feature rows
     * @param outcome target rows aligned with {@code features}, or {@code null}
     * @param batchSize rows per batch; {@code 0} means a single batch containing all rows
     * @param numThreads number of worker threads used to evaluate batches in parallel (at least 1)
     */
    public AbstractMiniBatchCostFunction(DoubleVector[] features, DoubleVector[] outcome, int batchSize, int numThreads) {
        this(features, outcome, batchSize, numThreads, false);
    }

    /**
     * Creates a mini-batch cost function.
     *
     * @param features training feature rows; must not be empty
     * @param outcome target rows aligned with {@code features}, or {@code null}, in which case
     *     every batch gets a {@code null} outcome matrix
     * @param batchSize rows per batch, in {@code [0, features.length]}; {@code 0} means a single
     *     batch containing all rows (and forces one thread)
     * @param numThreads number of worker threads (at least 1); ignored in stochastic mode,
     *     which always uses one thread
     * @param stochastic if true, each {@link #evaluateCost(DoubleVector)} call evaluates only one
     *     batch, cycling through the batches in row order
     * @throws IllegalArgumentException if {@code features} is empty, or {@code batchSize} or
     *     {@code numThreads} is out of range
     */
    public AbstractMiniBatchCostFunction(DoubleVector[] features, DoubleVector[] outcome, int batchSize, int numThreads, boolean stochastic) {
        // Without this check an empty input fails later with an opaque index error at featureRows[0].
        Preconditions.checkArgument(features.length > 0, "Need at least one feature row");
        Preconditions.checkArgument(batchSize >= 0 && batchSize <= features.length,
                "Batchsize wasn't in range of 0-" + features.length);
        Preconditions.checkArgument(numThreads >= 1, "#Threads need to be at least > 0");
        this.stochastic = stochastic;
        this.batchOffset = 0;

        // A List (not a hash set) keeps the batches in deterministic row order.
        List<Boundaries.Range> ranges = new ArrayList<>();
        if (batchSize == 0) {
            // Full-batch mode: a single range, so extra threads would sit idle.
            numThreads = 1;
            ranges.add(new Boundaries.Range(0, features.length - 1));
        } else {
            for (int start = 0; start < features.length; start += batchSize) {
                ranges.add(new Boundaries.Range(start, Math.min(features.length - 1, start + batchSize - 1)));
            }
        }

        ThreadFactory threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat("MiniBatch Worker %d").build();
        this.pool = Executors.newFixedThreadPool(stochastic ? 1 : numThreads, threadFactory);

        this.batches = new ArrayList<>(ranges.size());
        for (Boundaries.Range range : ranges) {
            this.batches.add(buildBatch(features, outcome, range.getStart(), range.getEnd()));
        }
    }

    /**
     * Builds the (features, outcome) matrix pair for rows {@code start..end}.
     *
     * <p>End is presumably inclusive, matching how the ranges are built in the constructor (TODO
     * confirm against {@code ArrayUtils.subArray}). The storage type (sparse vs dense) is chosen
     * from the first row of the slice.
     */
    private static Tuple<DoubleMatrix, DoubleMatrix> buildBatch(DoubleVector[] features, DoubleVector[] outcome, int start, int end) {
        DoubleVector[] featureRows = (DoubleVector[]) ArrayUtils.subArray(features, start, end);
        boolean sparse = featureRows[0].isSparse();
        DoubleMatrix featureMatrix = sparse ? new SparseDoubleRowMatrix(featureRows) : new DenseDoubleMatrix(featureRows);
        // Combine a vector of ones with the features; presumably this prepends a bias/intercept
        // column (NOTE(review): confirm the (vector, matrix) constructor semantics).
        DenseDoubleVector ones = DenseDoubleVector.ones(featureRows.length);
        DoubleMatrix featuresWithOnes = sparse ? new SparseDoubleRowMatrix(ones, featureMatrix) : new DenseDoubleMatrix(ones, featureMatrix);
        DoubleMatrix outcomeMatrix = outcome == null ? null : new DenseDoubleMatrix((DoubleVector[]) ArrayUtils.subArray(outcome, start, end));
        return new Tuple<>(featuresWithOnes, outcomeMatrix);
    }

    /**
     * Evaluates cost and gradient of the given parameters over the training batches.
     *
     * <p>With a single batch, it is evaluated on the calling thread. Otherwise:
     * <ul>
     *   <li>in non-stochastic mode every batch is submitted to the worker pool;</li>
     *   <li>in stochastic mode only the batch at the current cursor is evaluated, and the cursor
     *       advances, wrapping around after the last batch.</li>
     * </ul>
     * The result is the average cost and average gradient over the evaluated batches.
     *
     * @param input parameter vector to evaluate
     * @return the averaged cost/gradient, or {@code null} if a batch evaluation failed or the
     *     calling thread was interrupted (the stack trace is printed in both cases)
     */
    @Override
    public final CostGradientTuple evaluateCost(DoubleVector input) {
        if (batches.size() == 1) {
            try {
                return new CallableMiniBatch(batches.get(0), input).call();
            } catch (Exception e) {
                e.printStackTrace();
                return null;
            }
        }

        ExecutorCompletionService<CostGradientTuple> completionService = new ExecutorCompletionService<>(pool);
        int submitted;
        if (stochastic) {
            // Exactly one batch per call. The previous loop never terminated in this mode.
            completionService.submit(new CallableMiniBatch(batches.get(batchOffset), input));
            submitted = 1;
            batchOffset = (batchOffset + 1) % batches.size();
        } else {
            for (Tuple<DoubleMatrix, DoubleMatrix> batch : batches) {
                completionService.submit(new CallableMiniBatch(batch, input));
            }
            submitted = batches.size();
        }

        double costSum = 0.0d;
        DoubleVector gradientSum = new DenseDoubleVector(input.getLength());
        for (int i = 0; i < submitted; i++) {
            try {
                CostGradientTuple result = completionService.take().get();
                costSum += result.getCost();
                gradientSum = gradientSum.add(result.getGradient());
            } catch (InterruptedException e) {
                // Restore the interrupt status so the caller can still observe the interruption.
                Thread.currentThread().interrupt();
                e.printStackTrace();
                return null;
            } catch (ExecutionException e) {
                e.printStackTrace();
                return null;
            }
        }
        return new CostGradientTuple(costSum / submitted, gradientSum.divide(submitted));
    }

    /**
     * Computes cost and gradient for a single batch.
     *
     * <p>When more than one worker thread is configured, this may be invoked concurrently from
     * pool threads, so implementations must be safe for concurrent use in that case.
     *
     * @param input parameter vector being evaluated
     * @param features the batch feature matrix, built together with a vector of ones
     * @param outcome the batch outcome matrix, or {@code null} if no outcome was supplied
     * @return cost and gradient for this batch
     */
    protected abstract CostGradientTuple evaluateBatch(DoubleVector input, DoubleMatrix features, DoubleMatrix outcome);
}
