package cc.mallet.fst;

import cc.mallet.optimize.Optimizable;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Parallel wrapper around an {@link Optimizable.ByCombiningBatchGradient}.
 *
 * <p>The training set's instance indices are split into {@code optimizable.getNumBatches()}
 * contiguous ranges (see {@link #createTasks()}). Each range gets one {@link ValueHandler} and one
 * {@link GradientHandler}. These run on a fixed-size thread pool with one thread per batch.
 * Per-batch results are cached in {@link #batchCachedValue} and {@link #batchCachedGradient}. They
 * are recomputed only when {@link #cacheIndicator} reports them stale. The per-batch gradients are
 * merged by the wrapped optimizable's {@code combineGradients}.
 *
 * <p>The caches are written without synchronization by {@link #getValue()} and
 * {@link #getValueGradient(double[])}. Presumably these are called from a single optimizer
 * thread; confirm before sharing an instance between threads.
 *
 * <p>Call {@link #shutdown()} when finished so that the pool's threads are released.
 */
public class ThreadedOptimizable implements Optimizable.ByGradientValue {
    private static final Logger logger = MalletLogger.getLogger(ThreadedOptimizable.class.getName());

    /** Training instances whose indices are partitioned into batches. */
    protected InstanceList trainingSet;
    /** The wrapped batch-aware optimizable that performs the per-batch work. */
    protected Optimizable.ByCombiningBatchGradient optimizable;
    /** Most recently computed value of each batch, indexed by batch number. */
    protected double[] batchCachedValue;
    /**
     * One gradient buffer per batch, indexed by batch number. Each buffer is passed to
     * {@code getBatchValueGradient}, which presumably fills it in place.
     */
    protected List<double[]> batchCachedGradient;
    /** Reports whether the cached values and gradients must be recomputed. */
    protected CacheStaleIndicator cacheIndicator;

    private transient Collection<Callable<Double>> valueTasks;
    private transient Collection<Callable<Boolean>> gradientTasks;
    private transient ThreadPoolExecutor executor;

    /** Not referenced within this class; kept for API compatibility. */
    public static final int SLEEP_TIME = 100;

    /** Task that computes one batch's gradient into that batch's cached buffer. */
    public class GradientHandler implements Callable<Boolean> {
        private final int batchIndex;
        private final int[] batchAssignments;

        /**
         * @param batchIndex index of the batch, and of its buffer in
         *     {@link ThreadedOptimizable#batchCachedGradient}
         * @param batchAssignments instance range of the batch. {@link ThreadedOptimizable#createTasks()}
         *     builds it as a two-element half-open range {@code {start, end}}.
         */
        public GradientHandler(int batchIndex, int[] batchAssignments) {
            this.batchIndex = batchIndex;
            this.batchAssignments = batchAssignments;
        }

        /** Runs the batch gradient computation; returns {@code true} on normal completion. */
        @Override
        public Boolean call() {
            optimizable.getBatchValueGradient(batchCachedGradient.get(batchIndex), batchIndex, batchAssignments);
            return Boolean.TRUE;
        }
    }

    /** Task that computes one batch's value. */
    public class ValueHandler implements Callable<Double> {
        private final int batchIndex;
        private final int[] batchAssignments;

        /**
         * @param batchIndex index of the batch
         * @param batchAssignments instance range of the batch. {@link ThreadedOptimizable#createTasks()}
         *     builds it as a two-element half-open range {@code {start, end}}.
         */
        public ValueHandler(int batchIndex, int[] batchAssignments) {
            this.batchIndex = batchIndex;
            this.batchAssignments = batchAssignments;
        }

        /** Returns the wrapped optimizable's value for this batch. */
        @Override
        public Double call() {
            return Double.valueOf(optimizable.getBatchValue(batchIndex, batchAssignments));
        }
    }

    /**
     * Creates the wrapper, allocates per-batch caches, and starts a thread pool with one thread
     * per batch.
     *
     * @param optimizable the batch-aware optimizable to parallelize. Its {@code getNumBatches()}
     *     determines the number of batches and threads.
     * @param trainingSet instances to partition into batches
     * @param numFactors length of each per-batch gradient buffer
     * @param cacheIndicator decides when cached values and gradients are stale
     * @throws IllegalArgumentException if {@code optimizable.getNumBatches() <= 0}
     */
    public ThreadedOptimizable(Optimizable.ByCombiningBatchGradient optimizable, InstanceList trainingSet,
            int numFactors, CacheStaleIndicator cacheIndicator) {
        this.trainingSet = trainingSet;
        this.optimizable = optimizable;
        int numBatches = optimizable.getNumBatches();
        // A real check rather than an assert: without -ea a non-positive count would otherwise
        // fail later with a less informative error (array size / thread pool size).
        if (numBatches <= 0) {
            throw new IllegalArgumentException("Invalid number of batches: " + numBatches);
        }
        this.batchCachedValue = new double[numBatches];
        this.batchCachedGradient = new ArrayList<double[]>(numBatches);
        for (int batch = 0; batch < numBatches; batch++) {
            this.batchCachedGradient.add(new double[numFactors]);
        }
        this.cacheIndicator = cacheIndicator;
        logger.info("Creating " + numBatches + " threads for updating gradient...");
        this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numBatches);
        createTasks();
    }

    /** Returns the wrapped batch-aware optimizable. */
    public Optimizable.ByCombiningBatchGradient getOptimizable() {
        return this.optimizable;
    }

    /**
     * Shuts down the thread pool. Waits up to 30 seconds for running tasks to finish, then
     * forcibly cancels anything left over.
     */
    public void shutdown() {
        executor.shutdown();
        try {
            if (!executor.awaitTermination(30L, TimeUnit.SECONDS)) {
                logger.warning("Worker threads did not terminate within 30 seconds; forcing shutdown");
            }
        } catch (InterruptedException e) {
            // Preserve the caller's interrupt status instead of swallowing it.
            Thread.currentThread().interrupt();
            logger.log(Level.WARNING, "Interrupted while waiting for worker threads to terminate", e);
        }
        // Called unconditionally (not inside the assert) so that lingering tasks are cancelled
        // even when assertions are disabled; it is a no-op if the pool already terminated.
        List<Runnable> neverStarted = executor.shutdownNow();
        assert neverStarted.isEmpty() : "All tasks didn't finish";
    }

    /**
     * Returns the sum of the per-batch values.
     *
     * <p>If the cache indicator reports the value stale, all batch values are first recomputed in
     * parallel and the new total is logged at INFO. A batch whose task fails keeps its previous
     * cached value; the failure is logged. If the calling thread is interrupted, the interrupt
     * status is restored and the previously cached values are summed.
     */
    @Override
    public double getValue() {
        if (!cacheIndicator.isValueStale()) {
            return MatrixOps.sum(batchCachedValue);
        }
        try {
            List<Future<Double>> results = executor.invokeAll(valueTasks);
            int batch = 0;
            for (Future<Double> result : results) {
                try {
                    batchCachedValue[batch] = result.get().doubleValue();
                } catch (ExecutionException e) {
                    logger.log(Level.SEVERE, "Value computation failed for batch " + batch
                            + "; keeping its previously cached value", e.getCause());
                }
                batch++;
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.log(Level.WARNING, "Interrupted while computing batch values; using cached values", e);
        }
        double sum = MatrixOps.sum(batchCachedValue);
        logger.info("getValue() (loglikelihood, optimizable by label likelihood) =" + sum);
        return sum;
    }

    /**
     * Writes the combined gradient into {@code buffer}.
     *
     * <p>If the cache indicator reports the gradient stale, {@link #getValue()} is called first.
     * All batch gradients are then recomputed in parallel before they are combined. Batch task
     * failures are logged. On interruption, the interrupt status is restored, and the per-batch
     * buffers may be stale or only partially recomputed.
     *
     * @param buffer destination passed to the wrapped optimizable's {@code combineGradients}
     */
    @Override
    public void getValueGradient(double[] buffer) {
        if (cacheIndicator.isGradientStale()) {
            // NOTE(review): the value cache is refreshed before gradients. Presumably
            // getBatchValueGradient relies on state set by getBatchValue; verify.
            getValue();
            try {
                List<Future<Boolean>> results = executor.invokeAll(gradientTasks);
                int batch = 0;
                for (Future<Boolean> result : results) {
                    try {
                        // Surfaces task failures, which invokeAll alone would silently drop.
                        result.get();
                    } catch (ExecutionException e) {
                        logger.log(Level.SEVERE, "Gradient computation failed for batch " + batch, e.getCause());
                    }
                    batch++;
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                logger.log(Level.WARNING, "Interrupted while computing batch gradients", e);
            }
        }
        optimizable.combineGradients(batchCachedGradient, buffer);
    }

    /**
     * Builds one value task and one gradient task per batch.
     *
     * <p>Instances are assigned as contiguous half-open ranges {@code [start, end)} of size
     * {@code trainingSet.size() / numBatches}. The last batch absorbs the remainder of the
     * integer division.
     */
    protected void createTasks() {
        int numBatches = optimizable.getNumBatches();
        valueTasks = new ArrayList<Callable<Double>>(numBatches);
        gradientTasks = new ArrayList<Callable<Boolean>>(numBatches);
        int numInstances = trainingSet.size();
        int batchSize = numInstances / numBatches;
        int start = 0;
        for (int batch = 0; batch < numBatches; batch++) {
            int end = (batch == numBatches - 1) ? numInstances : start + batchSize;
            // Separate arrays per task, so neither task can observe the other mutating its range.
            valueTasks.add(new ValueHandler(batch, new int[] {start, end}));
            gradientTasks.add(new GradientHandler(batch, new int[] {start, end}));
            start = end;
        }
    }

    /** Delegates to the wrapped optimizable. */
    @Override
    public int getNumParameters() {
        return this.optimizable.getNumParameters();
    }

    /** Delegates to the wrapped optimizable. */
    @Override
    public void getParameters(double[] buffer) {
        this.optimizable.getParameters(buffer);
    }

    /** Delegates to the wrapped optimizable. */
    @Override
    public double getParameter(int index) {
        return this.optimizable.getParameter(index);
    }

    /** Delegates to the wrapped optimizable. */
    @Override
    public void setParameters(double[] params) {
        this.optimizable.setParameters(params);
    }

    /** Delegates to the wrapped optimizable. */
    @Override
    public void setParameter(int index, double value) {
        this.optimizable.setParameter(index, value);
    }
}
