package com.github.cschen1205.falcon;

import java.util.Set;
import java.util.logging.Logger;

/* loaded from: input_file:com/github/cschen1205/falcon/TDLambdaFalcon.class */
/**
 * {@link TDFalcon} variant whose {@link #learnQ} applies one temporal-difference
 * error to every node, weighted by that node's trace value {@code e}.
 *
 * <p>The trace of the node selected for the current (input, action) pair is
 * incremented before the update and multiplied by {@code QGamma * lambda}
 * afterwards.
 *
 * <p>NOTE(review): {@link #getTraceUpdateMode() traceUpdateMode} is stored and
 * exposed through its accessors, but no code in this class reads it; the trace
 * is always incremented by 1 regardless of the configured mode. Confirm whether
 * that is intended.
 */
public class TDLambdaFalcon extends TDFalcon {
    /**
     * Class logger. Named with {@code Class.getName()} so that it takes part in
     * the normal dotted logger hierarchy ({@code String.valueOf(Class)} would
     * yield a name prefixed with {@code "class "}).
     */
    private static final Logger logger = Logger.getLogger(TDLambdaFalcon.class.getName());

    /** Factor that, multiplied by {@code QGamma}, decays node traces. Defaults to 0.9. */
    private double lambda = 0.9d;

    /** Configured trace update mode (see the class-level review note). */
    private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace;

    /** Not read anywhere in this class. */
    private boolean parallel = false;

    /**
     * Creates an instance with lambda 0.9 and {@code ReplaceTrace} mode.
     *
     * @param falconConfig configuration forwarded to the {@link TDFalcon} constructor
     */
    public TDLambdaFalcon(FalconConfig falconConfig) {
        super(falconConfig);
    }

    /**
     * Creates an instance with lambda 0.9 and {@code ReplaceTrace} mode.
     *
     * @param falconConfig configuration forwarded to the {@link TDFalcon} constructor
     * @param tDMethod     TD method forwarded to the {@link TDFalcon} constructor
     */
    public TDLambdaFalcon(FalconConfig falconConfig, TDMethod tDMethod) {
        super(falconConfig, tDMethod);
    }

    /** @return the currently configured trace update mode */
    public EligibilityTraceUpdateMode getTraceUpdateMode() {
        return this.traceUpdateMode;
    }

    /**
     * Performs one Q-learning step.
     *
     * <ol>
     *   <li>If {@code qValueProvider} is non-null and returns a valid Q value for
     *       ({@code dArr}, {@code dArr2}), that value is encoded as
     *       {@code [q, 1 - q]} and learned directly through {@code learn}.</li>
     *   <li>Otherwise a node is looked up with {@code searchQNode}. When none is
     *       found the call is delegated to {@code super.learnQ}.</li>
     *   <li>When a node is found, its trace is incremented, the TD error is
     *       obtained from {@code getTDError}, and every node's Q value is moved by
     *       {@code QAlpha * tdError * e} (additionally scaled by {@code 1 - q} when
     *       {@code config.isBounded}) and clamped to [0, 1].</li>
     * </ol>
     *
     * @param dArr           vector passed in the "inputs" position of the node search
     * @param dArr2          vector passed in the "actions" position of the node search
     * @param dArr3          forwarded unchanged to {@code getTDError} / {@code super.learnQ}
     * @param set            forwarded unchanged to {@code getTDError} / {@code super.learnQ}
     * @param d              forwarded unchanged to {@code getTDError} / {@code super.learnQ}
     * @param qValueProvider optional external Q-value source; may be {@code null}
     * @return the index returned by {@code learn}, by {@code super.learnQ}, or the
     *         index of the node found by the search, depending on the path taken
     */
    @Override
    public int learnQ(double[] dArr, double[] dArr2, double[] dArr3, Set<Integer> set, double d, QValueProvider qValueProvider) {
        QValue suppliedQ = qValueProvider == null ? QValue.Invalid() : qValueProvider.queryQValue(dArr, dArr2, false);
        if (suppliedQ.isValid()) {
            return learn(dArr, dArr2, qToRewardVector(suppliedQ.getValue()));
        }

        Tuple2<Integer, Double> match = searchQNode(dArr, dArr2, QValue.Invalid());
        int winnerIndex = match.getKey().intValue();
        double currentQ = match.getValue().doubleValue();
        if (winnerIndex == -1) {
            // No existing node was selected: fall back to the plain TD update.
            return super.learnQ(dArr, dArr2, dArr3, set, d, qValueProvider);
        }

        FalconNode winner = this.nodes.get(winnerIndex);
        winner.e += 1.0d;
        double tdError = getTDError(currentQ, dArr3, set, d, qValueProvider);
        double traceDecay = this.QGamma * this.lambda;

        for (int i = 0; i < this.nodes.size(); i++) {
            FalconNode node = this.nodes.get(i);
            double[] nodeInputs = node.weight_inputs;
            double nodeQ = readQ(node.weight_rewards);
            double delta = this.QAlpha * tdError * node.e;
            if (this.config.isBounded) {
                delta *= 1.0d - nodeQ;
            }
            double[] rewards = qToRewardVector(clamp(nodeQ + delta, 0.0d, 1.0d));

            if (i == winnerIndex) {
                // The selected node learns the action that was actually supplied.
                winner.learnTemplate(nodeInputs, dArr2, rewards, this.config);
                this.nodes.get(winnerIndex).e *= traceDecay;
            } else {
                // Every other node re-learns its own stored input/action weights.
                node.learnTemplate(nodeInputs, node.weight_actions, rewards, this.config);
                Tuple2<Integer, double[]> actionMatch = searchActionNode(nodeInputs);
                double[] matchedActions = actionMatch.getValue();
                int matchedIndex = actionMatch.getKey().intValue();
                if (matchedIndex != -1) {
                    // Decay the trace when the node chosen for this input encodes the
                    // same action id as dArr2; otherwise reset the trace to zero.
                    if (equals(matchedActions, dArr2)) {
                        this.nodes.get(matchedIndex).e *= traceDecay;
                    } else {
                        this.nodes.get(matchedIndex).e = 0.0d;
                    }
                }
            }
        }
        return winnerIndex;
    }

    /**
     * Builds a reward vector of length {@code numReward()} whose first two
     * entries are {@code q} and {@code 1 - q}.
     */
    private double[] qToRewardVector(double q) {
        double[] rewards = new double[numReward()];
        rewards[0] = q;
        rewards[1] = 1.0d - q;
        return rewards;
    }

    /**
     * Compares two action vectors by the index of their largest element.
     * This is an overload taking two arrays, not an override of
     * {@link Object#equals(Object)}.
     *
     * @return {@code true} when {@link #getActionId} yields the same index for both
     */
    protected boolean equals(double[] dArr, double[] dArr2) {
        return getActionId(dArr) == getActionId(dArr2);
    }

    /**
     * Returns the index of the largest element of {@code dArr}; the first such
     * index wins on ties.
     *
     * @return the arg-max index, or {@code -1} when no element is strictly
     *         greater than {@link Double#NEGATIVE_INFINITY} (e.g. an empty array)
     */
    protected int getActionId(double[] dArr) {
        double best = Double.NEGATIVE_INFINITY;
        int bestIndex = -1;
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] > best) {
                best = dArr[i];
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Selects a node using only the input vector (actions and rewards are passed
     * as {@code null} to {@code computeChoiceValues}).
     *
     * @param dArr input vector used to compute the choice values
     * @return the selected node index paired with a copy of that node's
     *         {@code weight_actions}, or {@code (-1, null)} when {@code compete}
     *         selects nothing
     */
    public Tuple2<Integer, double[]> searchActionNode(double[] dArr) {
        int selected = compete(computeChoiceValues(this.nodes, dArr, null, null, this.config));
        if (selected == -1) {
            return new Tuple2<>(-1, null);
        }
        double[] actionsCopy = (double[]) this.nodes.get(selected).weight_actions.clone();
        return new Tuple2<>(Integer.valueOf(selected), actionsCopy);
    }

    /**
     * Searches for a node for the given input/action pair and reports a Q value.
     *
     * <p>The reported Q value is {@code q0} when it is valid; otherwise it is
     * read from the selected node's reward weights, or {@code config.initialQ}
     * when no node is selected.
     *
     * <p>NOTE(review): a candidate is discarded (and vigilance raised) when
     * {@code isVigilanceConstraintSatisfied} returns {@code true}, and retained
     * when it returns {@code false}. That reads as inverted relative to the
     * method name; behaviour is preserved here — confirm against FalconNode.
     *
     * @return (node index or -1, Q value)
     */
    private Tuple2<Integer, Double> searchQNode(double[] inputs, double[] actions, QValue q0) {
        double[] choiceValues = computeChoiceValues(this.nodes, inputs, actions, null, this.config);
        double[] vigilance = {this.config.rho_inputs, this.config.rho_actions, this.config.rho_rewards};
        double[] rewards = dummyRewards();

        int candidate = -1;
        for (int attempt = 0; attempt < this.nodes.size(); attempt++) {
            candidate = compete(choiceValues);
            if (candidate == -1) {
                break;
            }
            FalconNode node = this.nodes.get(candidate);
            if (node.isVigilanceConstraintSatisfied(inputs, actions, rewards, vigilance)) {
                // Exclude this node from further competition and tighten vigilance.
                choiceValues[candidate] = -1.0d;
                candidate = -1;
                vigilance = node.raiseVigilance(inputs, actions, rewards, vigilance, this.config);
            }
        }

        if (candidate != -1) {
            return new Tuple2<>(Integer.valueOf(candidate), Double.valueOf(q0.isValid() ? q0.getValue() : readQ((double[]) this.nodes.get(candidate).weight_rewards.clone())));
        }
        return new Tuple2<>(-1, Double.valueOf(q0.isValid() ? q0.getValue() : this.config.initialQ));
    }

    /** @param eligibilityTraceUpdateMode the trace update mode to store */
    public void setTraceUpdateMode(EligibilityTraceUpdateMode eligibilityTraceUpdateMode) {
        this.traceUpdateMode = eligibilityTraceUpdateMode;
    }

    /** @return the lambda factor used in the trace decay {@code QGamma * lambda} */
    public double getLambda() {
        return this.lambda;
    }

    /** @param d new lambda factor; stored as-is without range validation */
    public void setLambda(double d) {
        this.lambda = d;
    }
}
