package com.github.cschen1205.falcon;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:com/github/cschen1205/falcon/Falcon.class */
/**
 * Learning / action-selection agent built from a growing list of {@link FalconNode} templates.
 *
 * <p>Each learned sample is an (inputs, actions, rewards) triple. Nodes compete on choice values
 * computed by {@link FalconNode#computeChoiceValue}; the winning node either absorbs the sample,
 * is overwritten, or is reset, and a new node is created when no existing node accepts it.
 *
 * <p>Subclasses can observe the process through the {@link #onNewNode(FalconNode)} and
 * {@link #onChoiceCompeted(int)} hooks.
 */
public class Falcon {
    /** Learned nodes in creation order; the node index is the id returned by {@code learn}. */
    public ArrayList<FalconNode> nodes = new ArrayList<>();

    /** Configuration providing field sizes and vigilance parameters; also passed to the nodes. */
    public FalconConfig config;

    /**
     * Creates an agent with no learned nodes.
     *
     * @param falconConfig configuration used by this agent and passed to its nodes
     */
    public Falcon(FalconConfig falconConfig) {
        this.config = falconConfig;
    }

    /** Returns the length of a reward vector, as configured by {@code config.numReward}. */
    public int numReward() {
        return this.config.numReward;
    }

    /** Returns the number of actions, as configured by {@code config.numAction}. */
    public int numAction() {
        return this.config.numAction;
    }

    /**
     * Learns a sample given as an action id and a scalar reward.
     *
     * <p>The action is one-hot encoded over {@link #numAction()} slots. The reward is clamped to
     * [0, 1] and encoded as {@code [r, 1 - r, 0, ...]} of length {@link #numReward()}.
     *
     * @param inputs   input (state) vector
     * @param actionId index of the action taken
     * @param reward   scalar reward; values outside [0, 1] are clamped
     * @return index of the node that learned the sample
     * @throws ArrayIndexOutOfBoundsException if {@code numReward() < 2} or {@code actionId} is
     *     outside {@code [0, numAction())}
     */
    public int learn(double[] inputs, int actionId, double reward) {
        return learn(inputs, actionId, toRewardVector(reward));
    }

    /**
     * Learns a sample given as an action vector and a scalar reward.
     *
     * <p>The reward is clamped to [0, 1] and encoded as {@code [r, 1 - r, 0, ...]} of length
     * {@link #numReward()}.
     *
     * @param inputs  input (state) vector
     * @param actions action vector
     * @param reward  scalar reward; values outside [0, 1] are clamped
     * @return index of the node that learned the sample
     * @throws ArrayIndexOutOfBoundsException if {@code numReward() < 2}
     */
    public int learn(double[] inputs, double[] actions, double reward) {
        return learn(inputs, actions, toRewardVector(reward));
    }

    /**
     * Builds the reward vector {@code [r, 1 - r, 0, ...]} of length {@link #numReward()}.
     * Here {@code r} is {@code reward} clamped to [0, 1] through the overridable {@link #clamp}.
     */
    private double[] toRewardVector(double reward) {
        double[] rewards = new double[numReward()];
        rewards[0] = clamp(reward, 0.0d, 1.0d);
        rewards[1] = 1.0d - rewards[0];
        return rewards;
    }

    /**
     * Restricts {@code value} to the closed interval {@code [min, max]}.
     *
     * <p>{@code NaN} is returned unchanged.
     *
     * @param value value to clamp
     * @param min   lower bound
     * @param max   upper bound
     * @return {@code min} if {@code value < min}, {@code max} if {@code value > max}, otherwise
     *     {@code value}
     */
    public double clamp(double value, double min, double max) {
        if (value < min) {
            return min;
        }
        if (value > max) {
            return max;
        }
        return value;
    }

    /**
     * Learns a sample given as an action id and a reward vector.
     *
     * @param inputs   input (state) vector
     * @param actionId index of the action taken; it is one-hot encoded over {@link #numAction()}
     *     slots
     * @param rewards  reward vector
     * @return index of the node that learned the sample
     * @throws ArrayIndexOutOfBoundsException if {@code actionId} is outside
     *     {@code [0, numAction())}
     */
    public int learn(double[] inputs, int actionId, double[] rewards) {
        // New double arrays are zero-initialised, so only the chosen slot needs setting.
        double[] actions = new double[numAction()];
        actions[actionId] = 1.0d;
        return learn(inputs, actions, rewards);
    }

    /**
     * Learns a sample, computing the nodes' choice values from {@code inputs} and {@code actions}.
     * The reward vector is not used for the choice computation.
     *
     * @param inputs  input (state) vector
     * @param actions action vector
     * @param rewards reward vector
     * @return index of the node that learned the sample
     */
    public int learn(double[] inputs, double[] actions, double[] rewards) {
        return learn(inputs, actions, rewards,
                computeChoiceValues(this.nodes, inputs, actions, null, this.config));
    }

    /**
     * Learns a sample using precomputed choice values.
     *
     * <p>Nodes are tried in decreasing order of choice value, at most {@code nodes.size()} times:
     * <ul>
     *   <li>If the winner satisfies the vigilance constraint, it learns the template.</li>
     *   <li>Otherwise, if it is a perfect mismatch for {@code inputs}, it is overwritten.</li>
     *   <li>Otherwise it is excluded from further competition and vigilance is updated via
     *       {@link FalconNode#raiseVigilance}.</li>
     * </ul>
     * If no node accepts the sample, a new node is appended and {@link #onNewNode} is invoked.
     *
     * <p>Note: {@code choiceValues} is modified in place (rejected entries are set to -1.0).
     *
     * @param inputs       input (state) vector
     * @param actions      action vector
     * @param rewards      reward vector
     * @param choiceValues one choice value per node, indexed like {@link #nodes}
     * @return index of the node that learned or was overwritten, or the new node's index
     */
    public int learn(double[] inputs, double[] actions, double[] rewards, double[] choiceValues) {
        double[] vigilance = {this.config.rho_inputs, this.config.rho_actions, this.config.rho_rewards};
        for (int attempt = 0; attempt < this.nodes.size(); attempt++) {
            int winner = compete(choiceValues);
            if (winner == -1) {
                break; // nothing left to compete; fall through to node creation
            }
            FalconNode node = this.nodes.get(winner);
            if (node.isVigilanceConstraintSatisfied(inputs, actions, rewards, vigilance)) {
                node.learnTemplate(inputs, actions, rewards, this.config);
                return winner;
            }
            if (node.isPerfectMismatch(inputs)) {
                node.overwrite(inputs, actions, rewards, this.config);
                return winner;
            }
            // Exclude the rejected node from the next competition round.
            // NOTE(review): this assumes every valid choice value is > -1.0; confirm against
            // FalconNode.computeChoiceValue.
            choiceValues[winner] = -1.0d;
            vigilance = node.raiseVigilance(inputs, actions, rewards, vigilance, this.config);
        }
        FalconNode created = new FalconNode(inputs, actions, rewards);
        this.nodes.add(created);
        onNewNode(created);
        return this.nodes.size() - 1;
    }

    /**
     * Hook invoked after a new node has been appended to {@link #nodes}.
     * The default implementation does nothing.
     *
     * @param falconNode the node just created
     */
    protected void onNewNode(FalconNode falconNode) {
    }

    /**
     * Selects an action for {@code inputs}.
     *
     * <p>This base implementation ignores {@code qValueProvider} and delegates to
     * {@link #selectDirectionActionId(double[])}.
     *
     * @param inputs         input (state) vector
     * @param qValueProvider unused here; available to subclasses
     * @return selected action id
     */
    public int selectActionId(double[] inputs, QValueProvider qValueProvider) {
        return selectDirectionActionId(inputs);
    }

    /**
     * Selects an action for {@code inputs} without a Q-value provider.
     *
     * @param inputs input (state) vector
     * @return selected action id
     */
    public int selectActionId(double[] inputs) {
        return selectActionId(inputs, (QValueProvider) null);
    }

    /**
     * Selects an action for {@code inputs}, restricted to {@code allowedActions}.
     *
     * <p>This base implementation ignores {@code qValueProvider} and delegates to
     * {@link #selectDirectionActionId(double[], Set)}.
     *
     * @param inputs         input (state) vector
     * @param allowedActions action ids that may be returned
     * @param qValueProvider unused here; available to subclasses
     * @return selected action id, or -1 if none can be chosen
     */
    public int selectActionId(double[] inputs, Set<Integer> allowedActions, QValueProvider qValueProvider) {
        return selectDirectionActionId(inputs, allowedActions);
    }

    /**
     * Selects an action for {@code inputs}, restricted to {@code allowedActions}, without a
     * Q-value provider.
     *
     * @param inputs         input (state) vector
     * @param allowedActions action ids that may be returned
     * @return selected action id, or -1 if none can be chosen
     */
    public int selectActionId(double[] inputs, Set<Integer> allowedActions) {
        return selectActionId(inputs, allowedActions, null);
    }

    /**
     * Finds the node with the highest choice value for {@code inputs}.
     *
     * @param inputs input (state) vector
     * @return a copy of the winning node's {@code weight_actions}, or {@code null} when there is
     *     no winner (for example, no nodes have been learned yet)
     */
    public double[] searchAction(double[] inputs) {
        int winner = compete(computeChoiceValues(this.nodes, inputs, null, null, this.config));
        if (winner == -1) {
            return null;
        }
        return this.nodes.get(winner).weight_actions.clone();
    }

    /**
     * Returns the index of the largest value in {@code choiceValues}.
     *
     * <p>Ties go to the lowest index. The result is -1 if the array is empty or no entry exceeds
     * negative infinity (for example, all NaN). {@link #onChoiceCompeted} is called with the
     * result, including -1.
     *
     * @param choiceValues one value per candidate
     * @return index of the winner, or -1
     */
    public int compete(double[] choiceValues) {
        double bestValue = Double.NEGATIVE_INFINITY;
        int winner = -1;
        for (int index = 0; index < choiceValues.length; index++) {
            if (choiceValues[index] > bestValue) {
                bestValue = choiceValues[index];
                winner = index;
            }
        }
        onChoiceCompeted(winner);
        return winner;
    }

    /**
     * Hook invoked with the result of every {@link #compete} call.
     * The default implementation does nothing.
     *
     * @param i winning index, or -1 if there was no winner
     */
    protected void onChoiceCompeted(int i) {
    }

    /**
     * Computes each node's choice value for the given sample.
     *
     * @param nodeList nodes to evaluate
     * @param inputs   input vector passed to each node
     * @param actions  action vector passed to each node (may be {@code null}, as done by callers
     *     in this class)
     * @param rewards  reward vector passed to each node (may be {@code null}, as done by callers
     *     in this class)
     * @param config   configuration passed to each node
     * @return array whose i-th entry is {@code nodeList.get(i).computeChoiceValue(...)}
     */
    public static double[] computeChoiceValues(List<FalconNode> nodeList, double[] inputs, double[] actions, double[] rewards, FalconConfig config) {
        int size = nodeList.size();
        double[] choiceValues = new double[size];
        for (int i = 0; i < size; i++) {
            choiceValues[i] = nodeList.get(i).computeChoiceValue(inputs, actions, rewards, config);
        }
        return choiceValues;
    }

    /**
     * Picks the allowed action with the largest weight in the winning node's action template.
     *
     * <p>If there is no winning node, a uniformly random element of {@code allowedActions} is
     * returned, or -1 if the set is empty.
     *
     * @param inputs         input (state) vector
     * @param allowedActions action ids that may be returned; each must be a valid index into the
     *     action template
     * @return selected action id, or -1 if none can be chosen
     */
    public int selectDirectionActionId(double[] inputs, Set<Integer> allowedActions) {
        double[] actionWeights = searchAction(inputs);
        if (actionWeights == null) {
            if (allowedActions.isEmpty()) {
                return -1;
            }
            List<Integer> candidates = new ArrayList<>(allowedActions);
            // Cast after multiplying; casting Math.random() first would always yield 0.
            return candidates.get((int) (Math.random() * candidates.size()));
        }
        int best = -1;
        double bestWeight = Double.NEGATIVE_INFINITY;
        for (Integer actionId : allowedActions) {
            if (actionWeights[actionId.intValue()] > bestWeight) {
                bestWeight = actionWeights[actionId.intValue()];
                best = actionId.intValue();
            }
        }
        return best;
    }

    /**
     * Picks the action with the largest weight in the winning node's action template.
     *
     * <p>If there is no winning node, a uniformly random action id in
     * {@code [0, numAction())} is returned.
     *
     * @param inputs input (state) vector
     * @return selected action id
     */
    public int selectDirectionActionId(double[] inputs) {
        double[] actionWeights = searchAction(inputs);
        int numAction = numAction();
        if (actionWeights == null) {
            // Cast after multiplying; casting Math.random() first would always yield 0.
            return (int) (Math.random() * numAction);
        }
        int best = -1;
        double bestWeight = Double.NEGATIVE_INFINITY;
        for (int actionId = 0; actionId < numAction; actionId++) {
            if (actionWeights[actionId] > bestWeight) {
                bestWeight = actionWeights[actionId];
                best = actionId;
            }
        }
        return best;
    }

    /** Returns the configuration this agent was created with (or last assigned). */
    public FalconConfig getConfig() {
        return this.config;
    }
}
