package de.jungblut.classification.meta;

import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.classification.Classifier;
import de.jungblut.classification.ClassifierFactory;
import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.partition.BlockPartitioner;
import de.jungblut.partition.Boundaries;
import de.jungblut.writable.MatrixWritable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Ensemble meta-classifier that combines the predictions of several base classifiers.
 *
 * <p>An ensemble is either created untrained from a {@link ClassifierFactory} (see {@link #create})
 * and then trained with {@link #train}, or assembled from already trained models (see
 * {@link #fromTrainedModels}). {@link #selectionType(SelectionType)} controls how the training
 * data is distributed to the base classifiers. The {@link CombiningType} controls how their
 * predictions are merged.
 *
 * @param <A> the type of the base classifiers.
 */
public final class Voter<A extends Classifier> extends AbstractClassifier {

    private static final Logger LOG = LogManager.getLogger(Voter.class);

    private final Classifier[] classifier;
    // null for ensembles built via fromTrainedModels until setCombiningType is called
    private CombiningType type;
    private SelectionType selection = SelectionType.NONE;
    private int threads = 1;
    private boolean verbose;

    /** Strategy used to merge the predictions of the base classifiers. */
    public enum CombiningType {
        /** Every classifier votes for its predicted class; the class with the most votes wins. */
        MAJORITY,
        /** Element-wise mean of all prediction vectors. */
        AVERAGE,
        /** Element-wise sum of all prediction vectors, divided by the sum of its entries. */
        PROBABILITY
    }

    /** Strategy used by {@link #train} to derive the training set of every base classifier. */
    public enum SelectionType {
        /** The data is cut into contiguous blocks in input order, one block per classifier. */
        NONE,
        /**
         * Like {@link #NONE}, but features and outcomes are shuffled first. The shuffle is
         * applied to the caller's arrays, since the shuffled result is not copied.
         */
        SHUFFLE,
        /** Every classifier gets a sample drawn with replacement, as large as the full data set. */
        BAGGING
    }

    /** Trains a single base classifier on its split; executed on the training thread pool. */
    static final class TrainingWorker implements Callable<Boolean> {
        private final Classifier cls;
        private final TrainingSplit split;

        TrainingWorker(Classifier classifier, TrainingSplit trainingSplit) {
            this.cls = classifier;
            this.split = trainingSplit;
        }

        @Override
        public Boolean call() throws Exception {
            cls.train(split.getTrainFeatures(), split.getTrainOutcome());
            return Boolean.TRUE;
        }
    }

    private Voter(CombiningType combiningType, int numClassifiers, ClassifierFactory<A> classifierFactory) {
        this.type = combiningType;
        this.classifier = new Classifier[numClassifiers];
        for (int i = 0; i < numClassifiers; i++) {
            this.classifier[i] = classifierFactory.newInstance();
        }
    }

    private Voter(List<A> trainedModels) {
        this.classifier = new Classifier[trainedModels.size()];
        for (int i = 0; i < classifier.length; i++) {
            classifier[i] = Preconditions.checkNotNull(trainedModels.get(i));
        }
    }

    /**
     * Trains every base classifier on its own split of the data.
     *
     * <p>Training runs on a fixed pool of {@link #numThreads(int)} threads. If the calling thread
     * is interrupted while waiting, the interrupt flag is restored and the method returns. In that
     * case the ensemble may be only partially trained.
     *
     * @param features the training features.
     * @param outcome the training outcomes, aligned with {@code features}.
     * @throws IllegalStateException if training of any base classifier throws; the original
     *     exception is attached as the cause.
     */
    @Override
    public void train(DoubleVector[] features, DoubleVector[] outcome) {
        List<TrainingSplit> splits = createSplits(features, outcome);
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            ExecutorCompletionService<Boolean> completion = new ExecutorCompletionService<>(pool);
            for (int i = 0; i < classifier.length; i++) {
                completion.submit(new TrainingWorker(classifier[i], splits.get(i)));
            }
            for (int i = 0; i < classifier.length; i++) {
                // get() surfaces a worker's exception instead of silently discarding it
                completion.take().get();
                if (verbose) {
                    LOG.info("Finished with training classifier {} of {}", i + 1, classifier.length);
                }
            }
            if (verbose) {
                LOG.info("Successfully finished training!");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOG.warn("Training was interrupted, the ensemble may be only partially trained.", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Training of a base classifier failed.", e.getCause());
        } finally {
            // also interrupts still-running workers if we leave early
            pool.shutdownNow();
        }
    }

    /**
     * Predicts with every base classifier and merges the results with the configured
     * {@link CombiningType}.
     *
     * @param features the input to classify.
     * @return the combined prediction, with the same dimension as the base classifiers' output.
     * @throws IllegalStateException if the ensemble is empty or no combining type has been set.
     */
    @Override
    public DoubleVector predict(DoubleVector features) {
        Preconditions.checkState(classifier.length > 0, "The ensemble contains no classifiers.");
        Preconditions.checkState(type != null,
            "No combining type set, call setCombiningType() before predicting.");
        DoubleVector[] predictions = new DoubleVector[classifier.length];
        for (int i = 0; i < classifier.length; i++) {
            predictions[i] = classifier[i].predict(features);
        }
        int outputDimension = predictions[0].getDimension();
        // a single output unit encodes a binary problem, so there are two classes to vote on
        int numClasses = outputDimension == 1 ? 2 : outputDimension;
        switch (type) {
            case MAJORITY:
                return majorityVote(predictions, numClasses, outputDimension);
            case PROBABILITY:
                return normalizedSum(predictions);
            case AVERAGE:
                return average(predictions, outputDimension);
            default:
                throw new UnsupportedOperationException("Type " + type + " isn't supported yet!");
        }
    }

    /**
     * Returns the base classifiers.
     *
     * @return the internal array (not a copy); modifying it changes this ensemble.
     */
    public Classifier[] getClassifier() {
        return this.classifier;
    }

    /**
     * Enables progress logging during training.
     *
     * @return this instance, for chaining.
     */
    public Voter<A> verbose() {
        return verbose(true);
    }

    /**
     * Enables or disables progress logging during training.
     *
     * @param verbose true to log progress.
     * @return this instance, for chaining.
     */
    public Voter<A> verbose(boolean verbose) {
        this.verbose = verbose;
        return this;
    }

    /**
     * Sets how the training data is distributed to the base classifiers.
     *
     * @param selectionType the selection strategy.
     * @return this instance, for chaining.
     */
    public Voter<A> selectionType(SelectionType selectionType) {
        this.selection = selectionType;
        return this;
    }

    /**
     * Sets the number of threads used for training.
     *
     * @param numThreads the thread pool size; must be positive.
     * @return this instance, for chaining.
     * @throws IllegalArgumentException if {@code numThreads} is not positive.
     */
    public Voter<A> numThreads(int numThreads) {
        Preconditions.checkArgument(numThreads > 0,
            "Number of threads must be positive, but was %s", numThreads);
        this.threads = numThreads;
        return this;
    }

    /**
     * Sets how the predictions of the base classifiers are merged.
     *
     * @param combiningType the combining strategy.
     * @return this instance, for chaining.
     */
    public Voter<A> setCombiningType(CombiningType combiningType) {
        this.type = combiningType;
        return this;
    }

    /**
     * Majority vote.
     *
     * <p>For a single output unit, the result holds the winning class index at position 0.
     * Otherwise the result is one-hot at the winning class.
     */
    private DoubleVector majorityVote(DoubleVector[] predictions, int numClasses, int outputDimension) {
        int winner = ArrayUtils.maxIndex(createPredictionHistogram(predictions, numClasses));
        DoubleVector result = new DenseDoubleVector(outputDimension);
        // branch on the real output size: a two-unit output must be one-hot encoded as well
        if (outputDimension == 1) {
            result.set(0, winner);
        } else {
            result.set(winner, 1.0d);
        }
        return result;
    }

    /** Sums all predictions element-wise and divides by the total of the summed vector. */
    private static DoubleVector normalizedSum(DoubleVector[] predictions) {
        DoubleVector sum = predictions[0];
        for (int i = 1; i < predictions.length; i++) {
            sum = sum.add(predictions[i]);
        }
        return sum.divide(sum.sum());
    }

    /** Element-wise mean of all predictions. */
    private DoubleVector average(DoubleVector[] predictions, int outputDimension) {
        DoubleVector sum = new DenseDoubleVector(outputDimension);
        for (DoubleVector prediction : predictions) {
            sum = sum.add(prediction);
        }
        return sum.divide(classifier.length);
    }

    /** Counts, per class index, how many base classifiers predicted that class. */
    private double[] createPredictionHistogram(DoubleVector[] predictions, int numClasses) {
        double[] histogram = new double[numClasses];
        for (int i = 0; i < classifier.length; i++) {
            histogram[classifier[i].extractPredictedClass(predictions[i])] += 1.0d;
        }
        return histogram;
    }

    /** Builds one training split per base classifier according to the configured selection type. */
    private List<TrainingSplit> createSplits(DoubleVector[] features, DoubleVector[] outcome) {
        switch (selection) {
            case BAGGING:
                return bag(features, outcome);
            case SHUFFLE:
                return partition(features, outcome, true);
            default:
                return partition(features, outcome, false);
        }
    }

    /** Draws, for every classifier, {@code features.length} samples with replacement. */
    private List<TrainingSplit> bag(DoubleVector[] features, DoubleVector[] outcome) {
        List<TrainingSplit> splits = new ArrayList<>(classifier.length);
        Random random = new Random();
        for (int c = 0; c < classifier.length; c++) {
            DoubleVector[] sampledFeatures = new DoubleVector[features.length];
            DoubleVector[] sampledOutcome = new DoubleVector[features.length];
            for (int i = 0; i < features.length; i++) {
                int pick = random.nextInt(features.length);
                sampledFeatures[i] = features[pick];
                sampledOutcome[i] = outcome[pick];
            }
            splits.add(new TrainingSplit(sampledFeatures, sampledOutcome));
        }
        return splits;
    }

    /**
     * Splits the data into one contiguous block per classifier, optionally shuffling first.
     *
     * <p>NOTE(review): the computed boundaries are passed to {@code ArrayUtils.subArray}. Whether
     * its end index is inclusive decides whether adjacent splits share their boundary element.
     * Confirm this against {@code ArrayUtils}.
     */
    private List<TrainingSplit> partition(DoubleVector[] features, DoubleVector[] outcome, boolean shuffle) {
        List<TrainingSplit> splits = new ArrayList<>(classifier.length);
        if (shuffle) {
            // varargs call; the decompiled form wrapped outcome in a DoubleVector[] and did not compile
            ArrayUtils.multiShuffle(features, outcome);
        }
        List<Boundaries.Range> ranges = new ArrayList<>(
            new BlockPartitioner().partition(classifier.length, features.length).getBoundaries());
        int[] splitStarts = new int[classifier.length + 1];
        for (int i = 1; i < classifier.length; i++) {
            splitStarts[i] = ranges.get(i).getStart();
        }
        splitStarts[classifier.length] = features.length - 1;
        if (verbose) {
            LOG.info("Computed split ranges for 0-" + features.length + ": " + Arrays.toString(splitStarts) + "\n");
        }
        for (int i = 0; i < classifier.length; i++) {
            DoubleVector[] featureSplit =
                (DoubleVector[]) ArrayUtils.subArray(features, splitStarts[i], splitStarts[i + 1]);
            DoubleVector[] outcomeSplit =
                (DoubleVector[]) ArrayUtils.subArray(outcome, splitStarts[i], splitStarts[i + 1]);
            splits.add(new TrainingSplit(featureSplit, outcomeSplit));
        }
        return splits;
    }

    /**
     * Creates an untrained ensemble.
     *
     * @param numClassifiers how many base classifiers to create.
     * @param combiningType how predictions are merged.
     * @param classifierFactory creates each base classifier.
     * @return a new, untrained ensemble.
     */
    public static <K extends Classifier> Voter<K> create(int numClassifiers, CombiningType combiningType, ClassifierFactory<K> classifierFactory) {
        return new Voter<>(combiningType, numClassifiers, classifierFactory);
    }

    /**
     * Creates an ensemble from already trained models.
     *
     * <p>No combining type is set; call {@link #setCombiningType(CombiningType)} before
     * {@link #predict(DoubleVector)}.
     *
     * @param trainedModels the base classifiers; elements must be non-null.
     * @return an ensemble wrapping the given models.
     * @throws NullPointerException if any model is null.
     */
    public static <K extends Classifier> Voter<K> fromTrainedModels(List<K> trainedModels) {
        return new Voter<>(trainedModels);
    }
}
