package stream.classifier;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import stream.Data;
import stream.learner.HyperplaneModel;
import stream.learner.LearnerUtils;

/* loaded from: input_file:stream/classifier/Perceptron.class */
/**
 * Online perceptron for two-class problems, backed by a {@link HyperplaneModel}.
 *
 * <p>The first two distinct label values seen by {@link #train(Data)} become class 0 and
 * class 1 (in order of appearance); examples carrying any further label value are rejected.
 * On every misclassified example the bias and the per-attribute weights are moved by
 * {@code learnRate} towards the correct class.
 */
public class Perceptron extends AbstractClassifier {
    private static final long serialVersionUID = -3263838547557335984L;
    static Logger log = LoggerFactory.getLogger(Perceptron.class);

    /** Step size applied to bias and weights on each misclassified example. */
    Double learnRate;
    /** Kernel type handed to every {@link HyperplaneModel} created by this classifier. */
    int kernelType;
    /** Known class labels in order of first appearance; index 0 / 1 is the class index. */
    List<String> labels;
    /** Names of all numeric attributes seen so far, in order of first appearance. */
    List<String> attributes;
    private HyperplaneModel model;

    /** Creates a perceptron with kernel type {@code 1} and a learn rate of {@code 0.05}. */
    public Perceptron() {
        this(1, 0.05d);
    }

    /**
     * Creates a perceptron with kernel type {@code 1}, using {@code i} as the learn rate.
     *
     * <p>NOTE(review): the argument is forwarded as the learn rate, not as the kernel type.
     * That looks unintended for an {@code int} parameter, but the original intent cannot be
     * confirmed from this file, so the behavior is left untouched.
     *
     * @param i value used as the learn rate
     */
    public Perceptron(int i) {
        this(1, i);
    }

    /**
     * Creates a perceptron with the given kernel type and learn rate.
     *
     * @param i kernel type passed to the underlying {@link HyperplaneModel}
     * @param d learn rate
     */
    public Perceptron(int i, double d) {
        this.labels = new ArrayList<String>();
        this.attributes = new ArrayList<String>();
        this.kernelType = i;
        this.model = createModel(i);
        this.learnRate = Double.valueOf(d);
    }

    /** Builds a model of the given kernel type with no weights and a bias of zero. */
    private static HyperplaneModel createModel(int kernelType) {
        HyperplaneModel fresh = new HyperplaneModel(kernelType);
        fresh.initModel(new LinkedHashMap<String, Double>(), Double.valueOf(0.0d));
        return fresh;
    }

    /** @return the current learn rate */
    public Double getLearnRate() {
        return this.learnRate;
    }

    /**
     * Sets the learn rate used by subsequent calls to {@link #train(Data)}.
     *
     * @param d the new learn rate; must not be {@code null}, since training unboxes it
     */
    public void setLearnRate(Double d) {
        this.learnRate = d;
    }

    /**
     * Performs one perceptron update step with the given example.
     *
     * <p>The example is ignored (with a log message) if no label can be determined, if its
     * label is a third distinct class, or if it has no numeric attributes. Weights are only
     * changed when the current model misclassifies the example.
     *
     * @param data the labeled training example
     */
    @Override // stream.classifier.AbstractClassifier
    public void train(Data data) {
        if (this.label == null) {
            this.label = LearnerUtils.detectLabelAttribute(data);
        }
        if (this.label == null) {
            log.error("No label found for example!");
            // Without a label attribute there is nothing to look up; stop here.
            return;
        }
        Serializable labelValue = (Serializable) data.get(this.label);
        if (labelValue == null) {
            log.error("No label found for example!");
            return;
        }

        String labelName = labelValue.toString();
        int classIndex = this.labels.indexOf(labelName);
        if (classIndex < 0 && this.labels.size() < 2) {
            log.info("Adding class '{}'", labelName);
            this.labels.add(labelName);
            classIndex = this.labels.size() - 1;
        }
        if (classIndex < 0) {
            log.error("My labels are {}, unknown label: {}", this.labels, labelName);
            if (this.labels.size() == 2) {
                log.error("The perceptron algorithm only works for binary classification tasks!");
            }
            return;
        }

        Map<String, Double> numericVector = LearnerUtils.getNumericVector(data);
        if (numericVector.isEmpty()) {
            log.info("No numerical attributes found for learning! Ignoring example!");
            return;
        }
        for (String name : numericVector.keySet()) {
            if (!this.attributes.contains(name)) {
                this.attributes.add(name);
            }
        }

        Double predicted = this.model.predict(data);
        if (predicted == null || predicted.intValue() == classIndex) {
            // No prediction available, or the example is already classified correctly.
            return;
        }

        // Class 0 pulls the hyperplane in the negative direction, class 1 in the positive.
        double direction = classIndex == 0 ? -1.0d : 1.0d;
        double step = this.learnRate.doubleValue() * direction;
        this.model.setBias(this.model.getBias() + step);

        Map<String, Double> weights = this.model.getWeights();
        for (String name : this.attributes) {
            Double value = numericVector.get(name);
            if (value == null) {
                // Attribute was seen in earlier examples but is absent from this one:
                // it contributes nothing to the update (previously this caused an NPE).
                continue;
            }
            Double current = weights.get(name);
            double updated = (current == null ? 0.0d : current.doubleValue()) + (step * value.doubleValue());
            // Keep the map sparse by not creating zero entries, but always overwrite an
            // existing entry so a weight that drops to exactly zero does not keep its
            // stale non-zero value.
            if (updated != 0.0d || current != null) {
                weights.put(name, Double.valueOf(updated));
            }
        }
        this.model.setWeights(weights);
    }

    /**
     * Predicts the class label for the given example.
     *
     * <p>Decompiler note: renamed from {@code predict} (merged with bridge method).
     *
     * @param data the example to classify
     * @return the predicted label; the only known label if just one class has been seen;
     *         {@code null} if no class has been seen yet or the model yields no prediction
     */
    public String m3predict(Data data) {
        if (this.labels.isEmpty()) {
            log.warn("No labels available, predicting '?'!");
            return null;
        }
        if (this.labels.size() == 1) {
            log.warn("Only 1 label available, predicting '{}'!", this.labels.get(0));
            return this.labels.get(0);
        }
        Double score = this.model.predict(data);
        if (score == null) {
            // train() treats a null model output as "no prediction"; do the same here
            // instead of failing with a NullPointerException.
            log.warn("Model returned no prediction!");
            return null;
        }
        return score.doubleValue() < 0.5d ? this.labels.get(0) : this.labels.get(1);
    }

    /**
     * Replaces the model with a fresh one (no weights, zero bias) of the configured kernel
     * type. The known labels and attribute names are kept.
     */
    @Override // stream.classifier.AbstractClassifier
    public void reset() throws Exception {
        this.model = createModel(this.kernelType);
    }
}
