package com.github.chen0040.trees.id3;

import com.github.chen0040.data.frame.BasicDataFrame;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.CountRepository;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:com/github/chen0040/trees/id3/ID3TreeNode.class */
/**
 * A node of an ID3 decision tree built over categorical attributes.
 *
 * <p>Every node stores the majority class label of the rows it was built from. An internal
 * node additionally stores the index (into {@code columns}) of the attribute it splits on and
 * owns one child per observed value of that attribute; each child remembers the value that
 * leads to it in {@code attributeValue}.
 */
public class ID3TreeNode implements Cloneable {
    /** Euler–Mascheroni constant, used by {@link #heuristicCost(double)}. */
    private static final double EULER_MASCHERONI = 0.5772156649d;

    /** Number of training rows that reached this node. */
    private int rowCount;
    /** Index into {@link #columns} of the split attribute, or -1 when this node is a leaf. */
    private int splitAttributeIndex;
    /** Value of the parent's split attribute that selects this node ("" for the root). */
    private String attributeValue;
    /** Majority class label among the rows that reached this node. */
    private String classLabel;
    private final List<ID3TreeNode> childNodes = new ArrayList<>();
    private final List<String> columns = new ArrayList<>();

    /**
     * Makes this node a deep copy of {@code other}: scalar state and column names are copied,
     * and every child of {@code other} is cloned recursively.
     *
     * @param other the node to copy from; copying a node onto itself is a no-op
     */
    public void copy(ID3TreeNode other) {
        if (other == this) {
            // Clearing our own lists before reading them would wipe the subtree.
            return;
        }
        this.rowCount = other.rowCount;
        this.splitAttributeIndex = other.splitAttributeIndex;
        this.attributeValue = other.attributeValue;
        this.classLabel = other.classLabel;
        // The column names are needed by predict()/pathLength() to resolve the split attribute,
        // so a copy without them could not be used for prediction.
        this.columns.clear();
        this.columns.addAll(other.columns);
        this.childNodes.clear();
        for (ID3TreeNode child : other.childNodes) {
            this.childNodes.add((ID3TreeNode) child.clone());
        }
    }

    /**
     * Returns a deep copy of this node and its whole subtree.
     *
     * @return a new, independent {@code ID3TreeNode}
     */
    @Override
    public Object clone() {
        ID3TreeNode duplicate = new ID3TreeNode();
        duplicate.copy(this);
        return duplicate;
    }

    /** Creates an empty node; intended to be populated through {@link #copy(ID3TreeNode)}. */
    public ID3TreeNode() {
    }

    /**
     * Builds the subtree for the given rows.
     *
     * <p>The node becomes a leaf when there is at most one row, when {@code height} equals
     * {@code maxHeight}, when all rows share one class label (zero entropy), or when no
     * attribute yields a positive information gain. Otherwise it splits on the attribute with
     * the largest gain and recursively builds one child per observed value of that attribute.
     *
     * @param dataFrame   rows to build the node from
     * @param random      passed through unchanged to child nodes; not otherwise used here
     * @param height      depth of this node
     * @param maxHeight   depth at which growth stops
     * @param columnNames names of the categorical attributes to consider for splitting
     */
    public ID3TreeNode(DataFrame dataFrame, Random random, int height, int maxHeight, List<String> columnNames) {
        this.columns.addAll(columnNames);
        this.rowCount = dataFrame.rowCount();
        this.splitAttributeIndex = -1;
        this.attributeValue = "";
        this.classLabel = "";
        updateClassLabel(dataFrame);
        if (this.rowCount <= 1 || height == maxHeight) {
            return;
        }

        int columnCount = columnNames.size();
        CountRepository[] columnCounts = new CountRepository[columnCount];
        CountRepository classCounts = new CountRepository();
        for (int c = 0; c < columnCount; c++) {
            columnCounts[c] = new CountRepository(String.format("field%d", c));
        }

        // Tally, per attribute, how often each (attribute value, class label) pair occurs.
        for (int r = 0; r < this.rowCount; r++) {
            DataRow row = dataFrame.row(r);
            String classEvent = "ClassLabel=" + row.categoricalTarget();
            for (int c = 0; c < columnCount; c++) {
                // The row's value for this attribute, not the attribute's name, is the event.
                String value = row.getCategoricalCell(columnNames.get(c));
                columnCounts[c].addSupportCount(new String[]{value, classEvent});
                columnCounts[c].addSupportCount(new String[]{value});
                columnCounts[c].addSupportCount(new String[0]);
            }
            classCounts.addSupportCount(new String[]{classEvent});
            classCounts.addSupportCount(new String[0]);
        }

        double classEntropy = 0.0d;
        List<String> classEvents = classCounts.getSubEventNames(new String[0]);
        for (String classEvent : classEvents) {
            classEntropy += entropyTerm(classCounts.getProbability(classEvent));
        }
        if (classEntropy == 0.0d) {
            // Pure node: nothing to gain from splitting.
            return;
        }

        // Pick the attribute with the largest strictly positive gain (lowest index wins ties).
        double bestGain = 0.0d;
        for (int c = 0; c < columnCount; c++) {
            double gain = informationGain(columnCounts[c], classEntropy);
            if (gain > bestGain) {
                bestGain = gain;
                this.splitAttributeIndex = c;
            }
        }
        if (this.splitAttributeIndex < 0) {
            return;
        }

        // Partition the rows by their value of the chosen attribute and recurse.
        String splitColumn = columnNames.get(this.splitAttributeIndex);
        List<String> splitValues = columnCounts[this.splitAttributeIndex].getSubEventNames(new String[0]);
        DataFrame[] partitions = new DataFrame[splitValues.size()];
        for (int k = 0; k < partitions.length; k++) {
            partitions[k] = new BasicDataFrame();
        }
        for (int r = 0; r < this.rowCount; r++) {
            DataRow row = dataFrame.row(r);
            partitions[splitValues.indexOf(row.getCategoricalCell(splitColumn))].addRow(row);
        }
        for (int k = 0; k < partitions.length; k++) {
            partitions[k].lock();
            ID3TreeNode child = new ID3TreeNode(partitions[k], random, height + 1, maxHeight, columnNames);
            child.attributeValue = splitValues.get(k);
            this.childNodes.add(child);
        }
    }

    /**
     * Computes the information gain of one attribute: the class entropy minus the
     * probability-weighted class entropy within each observed attribute value.
     *
     * @param counts       tallies for the attribute, as filled in by the constructor
     * @param classEntropy entropy of the class label over all rows of the node
     * @return the gain, or 0 when the attribute has fewer than two observed values
     */
    private static double informationGain(CountRepository counts, double classEntropy) {
        List<String> values = counts.getSubEventNames(new String[0]);
        if (values.size() < 2) {
            // A constant attribute cannot separate rows; returning exactly 0 also keeps
            // floating-point noise from triggering a degenerate one-child split.
            return 0.0d;
        }
        double conditionalEntropy = 0.0d;
        for (String value : values) {
            double valueProbability = counts.getProbability(value);
            List<String> classEventsForValue = counts.getSubEventNames(new String[]{value});
            double entropyGivenValue = 0.0d;
            for (String classEvent : classEventsForValue) {
                entropyGivenValue += entropyTerm(counts.getConditionalProbability(value, classEvent));
            }
            conditionalEntropy += valueProbability * entropyGivenValue;
        }
        return classEntropy - conditionalEntropy;
    }

    /** Returns {@code -p * log2(p)}, treating a non-positive {@code p} as contributing 0. */
    private static double entropyTerm(double p) {
        return p > 0.0d ? (-p) * log2(p) : 0.0d;
    }

    /**
     * Returns {@code 2 * (ln(d - 1) + gamma) - 2 * (d - 1) / d} for {@code d > 1}, where
     * {@code gamma} is the Euler–Mascheroni constant, and 0 otherwise.
     *
     * <p>NOTE(review): this has the form of the average-path-length adjustment used by
     * isolation forests; {@link #pathLength(DataRow)} adds it for the rows left at a leaf.
     *
     * @param d number of rows
     * @return the cost estimate, or 0 when {@code d <= 1}
     */
    public static double heuristicCost(double d) {
        if (d <= 1.0d) {
            return 0.0d;
        }
        return (2.0d * (Math.log(d - 1.0d) + EULER_MASCHERONI)) - ((2.0d * (d - 1.0d)) / d);
    }

    /** Base-2 logarithm. */
    private static double log2(double d) {
        return Math.log(d) / Math.log(2.0d);
    }

    /**
     * Sets {@link #classLabel} to the most frequent categorical target among the rows of
     * {@code dataFrame}. The label is left unchanged when the frame has no rows.
     */
    private void updateClassLabel(DataFrame dataFrame) {
        HashMap<String, Integer> labelCounts = new HashMap<>();
        for (int r = 0; r < dataFrame.rowCount(); r++) {
            String label = dataFrame.row(r).categoricalTarget();
            Integer seen = labelCounts.get(label);
            labelCounts.put(label, seen == null ? 1 : seen + 1);
        }
        int bestCount = 0;
        for (String label : labelCounts.keySet()) {
            int count = labelCounts.get(label);
            if (count > bestCount) {
                bestCount = count;
                this.classLabel = label;
            }
        }
    }

    /**
     * Finds the child whose {@code attributeValue} equals the row's value for this node's
     * split attribute.
     *
     * @return the matching child, or {@code null} for a leaf or an unseen attribute value
     */
    private ID3TreeNode matchingChild(DataRow dataRow) {
        if (this.childNodes.isEmpty()) {
            return null;
        }
        String value = dataRow.getCategoricalCell(this.columns.get(this.splitAttributeIndex));
        for (ID3TreeNode child : this.childNodes) {
            if (child.attributeValue.equals(value)) {
                return child;
            }
        }
        return null;
    }

    /**
     * Predicts the class label for a row by descending the tree along the row's attribute
     * values.
     *
     * @param dataRow the row to classify
     * @return the label of the deepest node reached; when the row carries a value that was
     *         not seen at a split, the majority label of that split node is returned
     */
    public String predict(DataRow dataRow) {
        ID3TreeNode child = matchingChild(dataRow);
        return child != null ? child.predict(dataRow) : this.classLabel;
    }

    /**
     * Returns the number of edges followed for {@code dataRow} plus
     * {@link #heuristicCost(double)} of the row count at the node where the descent stops.
     *
     * @param dataRow the row to route through the tree
     * @return the adjusted path length
     */
    protected double pathLength(DataRow dataRow) {
        ID3TreeNode child = matchingChild(dataRow);
        return child != null ? child.pathLength(dataRow) + 1.0d : heuristicCost(this.rowCount);
    }
}
