package stream.classifier;

import java.io.Serializable;
import java.rmi.RemoteException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import stream.Data;
import stream.annotations.Description;
import stream.annotations.Parameter;
import stream.distribution.Distribution;
import stream.distribution.NominalDistribution;
import stream.distribution.NumericalDistribution;

/**
 * Incremental Naive Bayes classifier over {@link Data} items.
 *
 * <p>{@link #train(Data)} counts class labels in {@link #classDistribution} and, for every class,
 * maintains in {@link #distributions}:
 * <ul>
 *   <li>one {@link NominalDistribution} of {@code attribute='value'} conditions (for all
 *       non-{@code Double} attributes), stored under the class label itself;</li>
 *   <li>one numerical distribution per {@code Double} attribute, stored under a combined
 *       attribute/class key so that numerical likelihoods are class-conditional.</li>
 * </ul>
 * {@link #vote(Data)} multiplies each class prior with the per-attribute likelihoods and
 * normalises the products over all known classes.
 */
@Description(group = "Data Stream.Mining.Classifier")
public class NaiveBayes extends AbstractClassifier {
    private static final long serialVersionUID = 1095437834368310484L;
    static Logger log = LoggerFactory.getLogger(NaiveBayes.class);

    /** Counts of the observed class labels; recreated lazily by {@link #getClassDistribution()}. */
    protected NominalDistribution<String> classDistribution;

    /** Pseudo-count substituted when an attribute value was never seen together with a class. */
    Double laplaceCorrection = Double.valueOf(1.0E-4d);

    /** Configurable confidence gap; currently not read anywhere in this class. */
    Double confidenceGap = Double.valueOf(0.0d);

    /**
     * When {@code true}, class priors are ignored and every class starts with likelihood 1.0
     * (presumably "without prior" — NOTE(review): confirm the intended meaning of the name).
     */
    Boolean wop = false;

    /**
     * Per-class models: nominal condition counts keyed by class label, and numerical
     * distributions keyed by {@link #numericalKey(String, String)}.
     */
    protected Map<String, Distribution<?>> distributions = new HashMap<>();

    /** Creates an untrained classifier with an empty class distribution. */
    public NaiveBayes() {
        // NOTE: calls an overridable factory, so subclasses can supply their own distribution.
        this.classDistribution = createNominalDistribution();
    }

    /** @return the pseudo-count used for unseen nominal attribute values */
    public Double getLaplaceCorrection() {
        return this.laplaceCorrection;
    }

    /** @param d the pseudo-count used for unseen nominal attribute values */
    @Parameter(required = false, description = "Value of the la-place correction")
    public void setLaplaceCorrection(Double d) {
        this.laplaceCorrection = d;
    }

    /** @return the configured confidence gap (not used by the prediction logic) */
    public Double getConfidenceGap() {
        return this.confidenceGap;
    }

    /** @param d the confidence gap (not used by the prediction logic) */
    public void setConfidenceGap(Double d) {
        this.confidenceGap = d;
    }

    /** @return whether class priors are ignored during voting */
    public Boolean getWop() {
        return this.wop;
    }

    /** @param bool whether class priors should be ignored; {@code null} is treated as false */
    public void setWop(Boolean bool) {
        this.wop = bool;
    }

    /**
     * Computes the normalised probability of every known class for {@code data}.
     *
     * @param data the item to classify; the label attribute, if present, is ignored and
     *     attributes with a {@code null} value are treated as missing
     * @return class label to probability, in the iteration order of the class distribution;
     *     empty if no class has been trained yet. If no class has a positive likelihood the
     *     classes are returned with uniform probability instead of {@code NaN}.
     */
    public Map<String, Double> vote(Data data) {
        NominalDistribution<String> classes = getClassDistribution();
        log.debug("Predicting one of these classes: {}", classes.getElements());

        Map<String, Double> likelihoods = new LinkedHashMap<>();
        for (String clazz : classes.getElements()) {
            double likelihood = classPrior(classes, clazz);
            for (String attribute : data.keySet()) {
                if (attribute.equals(this.label)) {
                    continue;
                }
                Object value = data.get(attribute);
                if (value == null) {
                    // Missing value: contributes no evidence for any class.
                    continue;
                }
                if (value instanceof Double) {
                    likelihood *= numericalLikelihood(clazz, attribute, (Double) value);
                } else {
                    likelihood *= nominalLikelihood(clazz, attribute, data);
                }
            }
            likelihoods.put(clazz, likelihood);
        }

        // Sum over the map (not the element list) so duplicate class elements count once.
        double total = 0.0d;
        for (double weight : likelihoods.values()) {
            total += weight;
        }

        Map<String, Double> probabilities = new LinkedHashMap<>();
        for (Map.Entry<String, Double> entry : likelihoods.entrySet()) {
            // total <= 0 (or NaN) means no class is distinguishable: fall back to uniform.
            double probability = total > 0.0d
                    ? entry.getValue() / total
                    : 1.0d / likelihoods.size();
            probabilities.put(entry.getKey(), probability);
            log.debug("probability for {} is {}", entry.getKey(), probability);
        }
        return probabilities;
    }

    /** Prior weight of {@code clazz}: its relative frequency, or 1.0 when {@link #wop} is set. */
    private double classPrior(NominalDistribution<String> classes, String clazz) {
        if (Boolean.TRUE.equals(this.wop)) {
            return 1.0d;
        }
        double count = classes.getCount(clazz).doubleValue();
        Number total = classes.getTotalCount();
        log.debug("class likelihood for class '" + clazz + "' is {} / {}", count, total);
        return count / total.longValue();
    }

    /**
     * Likelihood of the nominal value of {@code attribute} given {@code clazz}; unseen values
     * use {@link #laplaceCorrection} as pseudo-count so they do not zero out the product.
     */
    private double nominalLikelihood(String clazz, String attribute, Data data) {
        String condition = getNominalCondition(attribute, data);
        double conditionCount = 0.0d;
        Distribution<?> model = this.distributions.get(clazz);
        if (model instanceof NominalDistribution) {
            @SuppressWarnings("unchecked")
            NominalDistribution<String> nominal = (NominalDistribution<String>) model;
            Number count = nominal.getCount(condition);
            if (count != null) {
                conditionCount = count.doubleValue();
            }
        }
        double classCount = getClassDistribution().getCount(clazz).doubleValue();
        if (conditionCount == 0.0d) {
            double laplace = this.laplaceCorrection.doubleValue();
            conditionCount = laplace;
            classCount += laplace;
        }
        double ratio = conditionCount / classCount;
        log.debug("  likelihood for {}  is  {}  |{} ", new Object[] {condition, ratio, clazz});
        return ratio;
    }

    /**
     * Likelihood of {@code value} under the numerical model of {@code attribute} for
     * {@code clazz}; 1.0 (no evidence) if that class never saw the attribute.
     */
    private double numericalLikelihood(String clazz, String attribute, Double value) {
        Distribution<?> model = this.distributions.get(numericalKey(attribute, clazz));
        if (model == null) {
            log.debug("No numerical model for attribute {}, {}", attribute, "class=" + clazz);
            return 1.0d;
        }
        @SuppressWarnings("unchecked")
        Distribution<Double> numerical = (Distribution<Double>) model;
        return numerical.prob(value).doubleValue();
    }

    /**
     * Predicts the most probable class for {@code data}.
     *
     * <p>NOTE(review): before decompilation this method was {@code predict} (merged with a bridge
     * method); the decompiled name is kept so existing callers still compile.
     *
     * @param data the item to classify
     * @return the class with the highest probability from {@link #vote(Data)} (the first one
     *     on ties), or {@code null} if no class has been trained
     */
    public String m1predict(Data data) {
        String best = null;
        double bestProbability = 0.0d;
        double gap = 0.0d;
        for (Map.Entry<String, Double> entry : vote(data).entrySet()) {
            double probability = entry.getValue();
            log.debug("probability for {} is {}", entry.getKey(), probability);
            if (best == null || probability > bestProbability) {
                best = entry.getKey();
                // Logged only: 1 minus the distance to the previous maximum.
                gap = 1.0d - Math.abs(probability - bestProbability);
                bestProbability = probability;
            }
        }
        log.info("Predicting class {}, label is: {}, confidence-gap: {} wop={}",
                new Object[] {best, data.get(this.label), gap, this.wop});
        return best;
    }

    /**
     * Builds the nominal condition string {@code attribute='value'} used as the key in the
     * per-class nominal distributions.
     */
    public String getNominalCondition(String str, Data data) {
        return str + "='" + data.get(str) + "'";
    }

    /**
     * Updates the model with one labelled example. Items are ignored when no label attribute
     * is configured or the item has no value for it; attributes with {@code null} values are
     * skipped.
     */
    @Override // stream.classifier.AbstractClassifier
    public void train(Data data) {
        if (this.label == null) {
            return;
        }
        Object labelValue = data.get(this.label);
        if (labelValue == null) {
            log.warn("Not processing unlabeled data item {}", data);
            return;
        }
        String clazz = labelValue.toString();
        log.debug("Learning from example with label={}", clazz);

        NominalDistribution<String> classes = getClassDistribution();
        if (log.isDebugEnabled()) {
            log.debug("Classes: {}", classes.getElements());
            for (String known : classes.getElements()) {
                log.debug("    {}:  {}", known, classes.getCount(known));
            }
        }
        classes.update(clazz);

        for (String attribute : data.keySet()) {
            // Exact match, consistent with vote(); the label itself is not a feature.
            if (attribute.equals(this.label)) {
                continue;
            }
            Object value = data.get(attribute);
            if (value == null) {
                continue;
            }
            if (value instanceof Double) {
                log.debug("Handling numerical case ({}) with value  {}", value, value);
                numericalModel(clazz, attribute).update((Double) value);
            } else {
                String condition = getNominalCondition(attribute, data);
                log.debug("Handling nominal case for [ {} | {} ]", condition, "class=" + clazz);
                nominalModel(clazz, attribute).update(condition);
            }
        }
    }

    /** Returns the nominal condition counts of {@code clazz}, creating them on first use. */
    private NominalDistribution<String> nominalModel(String clazz, String attribute) {
        Distribution<?> existing = this.distributions.get(clazz);
        if (existing instanceof NominalDistribution) {
            @SuppressWarnings("unchecked")
            NominalDistribution<String> nominal = (NominalDistribution<String>) existing;
            return nominal;
        }
        NominalDistribution<String> created = createNominalDistribution();
        log.debug("Creating new nominal distribution model for attribute {}, {}", attribute, "class=" + clazz);
        this.distributions.put(clazz, created);
        return created;
    }

    /** Returns the numerical model of {@code attribute} given {@code clazz}, creating it if needed. */
    private Distribution<Double> numericalModel(String clazz, String attribute) {
        String key = numericalKey(attribute, clazz);
        @SuppressWarnings("unchecked")
        Distribution<Double> model = (Distribution<Double>) this.distributions.get(key);
        if (model == null) {
            model = createNumericalDistribution();
            log.debug("Creating new numerical distribution model for attribute {}, {}", attribute, "class=" + clazz);
            this.distributions.put(key, model);
        }
        return model;
    }

    /** Map key for the numerical model of {@code attribute} conditioned on class {@code clazz}. */
    private static String numericalKey(String attribute, String clazz) {
        return attribute + "|class=" + clazz;
    }

    /** @return the class-label counts, created lazily if currently {@code null} */
    public NominalDistribution<String> getClassDistribution() {
        if (this.classDistribution == null) {
            this.classDistribution = createNominalDistribution();
        }
        return this.classDistribution;
    }

    /** @return a new list of all stored {@link NumericalDistribution} models (one per class and attribute) */
    public List<Distribution<Double>> getNumericalDistributions() {
        List<Distribution<Double>> numerical = new ArrayList<>();
        for (Distribution<?> distribution : this.distributions.values()) {
            if (distribution instanceof NumericalDistribution) {
                @SuppressWarnings("unchecked")
                Distribution<Double> typed = (Distribution<Double>) distribution;
                numerical.add(typed);
            }
        }
        return numerical;
    }

    /** Factory for nominal distributions; override to customise. */
    public NominalDistribution<String> createNominalDistribution() {
        return new NominalDistribution<>();
    }

    /** Factory for numerical distributions; override to customise. */
    public Distribution<Double> createNumericalDistribution() {
        return new NumericalDistribution();
    }

    @Override // stream.classifier.AbstractClassifier
    public String getName() throws RemoteException {
        return "stream.classifier.NaiveBayes";
    }

    /** Discards everything learned so far. */
    @Override // stream.classifier.AbstractClassifier
    public void reset() throws Exception {
        this.classDistribution = createNominalDistribution();
        this.distributions.clear();
    }
}
