package de.jungblut.classification.tree;

import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.classification.Classifier;
import de.jungblut.classification.ClassifierFactory;
import de.jungblut.classification.meta.Voter;
import de.jungblut.math.DoubleVector;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;

/* loaded from: input_file:de/jungblut/classification/tree/RandomForest.class */
/**
 * Classifier that trains {@code numTrees} {@link DecisionTree}s through a
 * {@link Voter} configured with {@code Voter.SelectionType.BAGGING} and combines
 * their predictions.
 *
 * <p>Instances are obtained through {@link #create(int)},
 * {@link #create(int, FeatureType[])} or {@link #deserialize(DataInput)} and are
 * configured with the fluent setters before {@link #train} is called.
 *
 * <p>NOTE(review): {@link #predict} and {@link #predictProbability} both change
 * the combining type of the shared voter before predicting, so mixing the two
 * concurrently on one instance looks unsafe - confirm against {@code Voter}.
 */
public final class RandomForest extends AbstractClassifier {
    /** Number of trees in the forest; {@link #train} requires more than one. */
    private final int numTrees;
    /** Type of every feature column; defaults to all NOMINAL on first training. */
    private FeatureType[] featureTypes;
    /** Number of threads handed to the voter for training. */
    private int numThreads = 1;
    /** Features each tree may choose from; values {@code <= 0} mean "derive from the dimension". */
    private int numRandomFeaturesToChoose = 0;
    /** Maximum height passed to every tree. */
    private int maxHeight = Integer.MAX_VALUE;
    private boolean verbose;
    /** Whether trees are created through {@code DecisionTree.createCompiledTree}. */
    private boolean compile = false;
    /** The ensemble; {@code null} until trained or deserialized. */
    private Voter<DecisionTree> trees;

    /**
     * Creates the individual trees for the voter, reading the forest's
     * configuration at the time each tree is requested.
     */
    private final class DecisionTreeFactory implements ClassifierFactory<DecisionTree> {
        private DecisionTreeFactory() {
        }

        @Override
        public DecisionTree newInstance() {
            if (RandomForest.this.compile) {
                return DecisionTree.createCompiledTree(RandomForest.this.featureTypes)
                        .setNumRandomFeaturesToChoose(RandomForest.this.numRandomFeaturesToChoose)
                        .setMaxHeight(RandomForest.this.maxHeight);
            }
            return DecisionTree.create(RandomForest.this.featureTypes)
                    .setNumRandomFeaturesToChoose(RandomForest.this.numRandomFeaturesToChoose)
                    .setMaxHeight(RandomForest.this.maxHeight);
        }
    }

    private RandomForest(int numTrees) {
        this.numTrees = numTrees;
    }

    private RandomForest(int numTrees, Voter<DecisionTree> voter) {
        this(numTrees);
        this.trees = voter;
    }

    /**
     * Trains the forest on the given examples.
     *
     * <p>If no feature types were set, all features are treated as
     * {@code FeatureType.NOMINAL}. If the number of random features is not
     * positive, it is set to {@code (int) sqrt(dimension)}. Both defaults are
     * stored on this instance and therefore persist across calls.
     *
     * @param doubleVectorArr the feature vectors; must not be empty
     * @param doubleVectorArr2 the outcome vectors, one per feature vector
     * @throws IllegalArgumentException if the arrays differ in length, no example
     *         is given, fewer than two trees are configured, the feature types do
     *         not match the feature dimension, or more random features are
     *         requested than there are features
     */
    @Override
    public void train(DoubleVector[] doubleVectorArr, DoubleVector[] doubleVectorArr2) {
        Preconditions.checkArgument(doubleVectorArr.length == doubleVectorArr2.length, "Number of examples and outcomes must match!");
        // The dimension is read from the first example, so there has to be one.
        Preconditions.checkArgument(doubleVectorArr.length > 0, "There must be at least one example to train on!");
        Preconditions.checkArgument(this.numTrees > 1, "There must be at least two trees to make up a forest!");

        int dimension = doubleVectorArr[0].getDimension();
        if (this.featureTypes == null) {
            this.featureTypes = new FeatureType[dimension];
            Arrays.fill(this.featureTypes, FeatureType.NOMINAL);
        }
        if (this.numRandomFeaturesToChoose <= 0) {
            this.numRandomFeaturesToChoose = (int) Math.sqrt(dimension);
        }

        Preconditions.checkArgument(this.featureTypes.length == dimension, "FeatureType length must match the dimension of the features! Given: " + dimension + ", but expected: " + this.featureTypes.length);
        // "<=" matches the message; a strict "<" rejected one-dimensional data,
        // where the sqrt default equals the dimension.
        Preconditions.checkArgument(this.numRandomFeaturesToChoose <= dimension, "Number of random features to choose must be lower or equal than the number of features!");

        this.trees = Voter.create(this.numTrees, Voter.CombiningType.MAJORITY, new DecisionTreeFactory())
                .selectionType(Voter.SelectionType.BAGGING)
                .numThreads(this.numThreads)
                .verbose(this.verbose);
        this.trees.train(doubleVectorArr, doubleVectorArr2);
    }

    /**
     * Predicts with the voter set to {@code Voter.CombiningType.MAJORITY}.
     *
     * @param doubleVector the feature vector to classify
     * @return the voter's prediction
     * @throws IllegalStateException if the forest was neither trained nor deserialized
     */
    @Override
    public DoubleVector predict(DoubleVector doubleVector) {
        Voter<DecisionTree> voter = requireTrees();
        voter.setCombiningType(Voter.CombiningType.MAJORITY);
        return voter.predict(doubleVector);
    }

    /**
     * Predicts with the voter set to {@code Voter.CombiningType.PROBABILITY}.
     *
     * @param doubleVector the feature vector to classify
     * @return the voter's prediction
     * @throws IllegalStateException if the forest was neither trained nor deserialized
     */
    @Override
    public DoubleVector predictProbability(DoubleVector doubleVector) {
        Voter<DecisionTree> voter = requireTrees();
        voter.setCombiningType(Voter.CombiningType.PROBABILITY);
        return voter.predict(doubleVector);
    }

    /**
     * Makes subsequent training create trees via {@code DecisionTree.createCompiledTree}.
     *
     * @return this instance
     */
    public RandomForest compile() {
        this.compile = true;
        return this;
    }

    /**
     * Enables verbose mode; equivalent to {@code verbose(true)}.
     *
     * @return this instance
     */
    public RandomForest verbose() {
        return verbose(true);
    }

    /**
     * Sets the verbose flag that is passed to the voter on training.
     *
     * @return this instance
     */
    public RandomForest verbose(boolean z) {
        this.verbose = z;
        return this;
    }

    /**
     * Sets the maximum height passed to every tree.
     *
     * @return this instance
     */
    public RandomForest setMaxHeight(int i) {
        this.maxHeight = i;
        return this;
    }

    /**
     * Sets the number of threads the voter uses for training.
     *
     * @return this instance
     */
    public RandomForest numThreads(int i) {
        this.numThreads = i;
        return this;
    }

    /**
     * Sets how many random features each tree may choose from. A value
     * {@code <= 0} makes {@link #train} use {@code (int) sqrt(dimension)}.
     *
     * @return this instance
     */
    public RandomForest setNumRandomFeaturesToChoose(int i) {
        this.numRandomFeaturesToChoose = i;
        return this;
    }

    /**
     * Sets the feature types; the array is stored as-is (not copied) and its
     * length must equal the feature dimension at training time.
     *
     * @return this instance
     */
    public RandomForest setFeatureTypes(FeatureType[] featureTypeArr) {
        this.featureTypes = featureTypeArr;
        return this;
    }

    /**
     * Creates an untrained forest with the given number of trees.
     *
     * @param i the number of trees; {@link #train} requires more than one
     */
    public static RandomForest create(int i) {
        return new RandomForest(i);
    }

    /**
     * Creates an untrained forest with the given number of trees and feature types.
     *
     * @param i the number of trees; {@link #train} requires more than one
     * @param featureTypeArr the type of each feature
     */
    public static RandomForest create(int i, FeatureType[] featureTypeArr) {
        return new RandomForest(i).setFeatureTypes(featureTypeArr);
    }

    /**
     * Writes the number of trees followed by every tree of the forest.
     *
     * @param randomForest the trained (or deserialized) forest to write
     * @param dataOutput the sink to write to
     * @throws IOException if writing fails
     * @throws IllegalStateException if the forest holds no trees yet
     */
    public static void serialize(RandomForest randomForest, DataOutput dataOutput) throws IOException {
        Voter<DecisionTree> voter = randomForest.requireTrees();
        dataOutput.writeInt(randomForest.numTrees);
        for (Classifier classifier : voter.getClassifier()) {
            DecisionTree.serialize((DecisionTree) classifier, dataOutput);
        }
    }

    /**
     * Reads a forest in the format written by {@link #serialize}: the number of
     * trees followed by that many serialized trees.
     *
     * @param dataInput the source to read from
     * @return a forest holding the deserialized trees
     * @throws IOException if reading fails
     */
    public static RandomForest deserialize(DataInput dataInput) throws IOException {
        int treeCount = dataInput.readInt();
        // The factory intentionally yields null: every classifier slot is filled
        // from the stream below instead of being freshly created.
        Voter<DecisionTree> voter = Voter.create(treeCount, Voter.CombiningType.MAJORITY, new ClassifierFactory<DecisionTree>() {
            @Override
            public DecisionTree newInstance() {
                return null;
            }
        });
        for (int i = 0; i < treeCount; i++) {
            voter.getClassifier()[i] = DecisionTree.deserialize(dataInput);
        }
        return new RandomForest(treeCount, voter);
    }

    /** Returns the ensemble, failing with a clear message if none exists yet. */
    private Voter<DecisionTree> requireTrees() {
        Preconditions.checkState(this.trees != null, "The forest must be trained or deserialized before it can be used!");
        return this.trees;
    }
}
