package com.github.cschen1205.falcon;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:com/github/cschen1205/falcon/TDFalcon.class */
/**
 * Temporal-difference learner on top of {@link Falcon}.
 *
 * <p>Q-values are read from the reward weights of the winning node (see {@link #searchQ}) or from an
 * optional {@link QValueProvider}. They are updated with either Q-learning or SARSA targets, as
 * selected by {@link #method}. Actions are chosen with an epsilon-greedy policy whose exploration
 * rate can be decayed via {@link #decayQEpsilon()}.
 */
public class TDFalcon extends Falcon {
    /** Exploration rate for epsilon-greedy action selection. */
    public double QEpsilon = 0.5d;
    /** Amount subtracted from {@link #QEpsilon} by each {@link #decayQEpsilon()} call. */
    public double QEpsilonDecay = 5.0E-4d;
    /** Floor below which {@link #decayQEpsilon()} never lowers {@link #QEpsilon}. */
    public double minQEpsilon = 0.005d;
    private final Random random = new Random();
    /** Discount factor applied to the next-state Q-value in the TD target. */
    public double QGamma = 0.9d;
    /** Learning rate that scales the TD error. */
    public double QAlpha = 0.5d;
    /** TD update rule; defaults to {@link TDMethod#QLearn}. */
    public TDMethod method = TDMethod.QLearn;

    /**
     * Creates a learner that uses Q-learning updates.
     *
     * @param falconConfig network configuration passed to {@link Falcon}
     */
    public TDFalcon(FalconConfig falconConfig) {
        super(falconConfig);
    }

    /**
     * Creates a learner with the given TD update rule.
     *
     * @param falconConfig network configuration passed to {@link Falcon}
     * @param tDMethod     TD update rule used when computing the TD error
     */
    public TDFalcon(FalconConfig falconConfig, TDMethod tDMethod) {
        super(falconConfig);
        this.method = tDMethod;
    }

    /**
     * Lowers {@link #QEpsilon} by {@link #QEpsilonDecay}, never going below {@link #minQEpsilon}.
     * Does nothing once epsilon is already at or below the floor.
     */
    public void decayQEpsilon() {
        if (QEpsilon > minQEpsilon) {
            // Clamp so the final decay step cannot overshoot the configured floor.
            QEpsilon = Math.max(minQEpsilon, QEpsilon - QEpsilonDecay);
        }
    }

    /**
     * Picks an action id uniformly at random from {@code candidates}.
     *
     * @return the chosen id, or -1 when {@code candidates} is empty
     */
    private int getRandomActionId(Set<Integer> candidates) {
        if (candidates.isEmpty()) {
            return -1;
        }
        List<Integer> ids = new ArrayList<>(candidates);
        return ids.get(random.nextInt(ids.size()));
    }

    /**
     * Draws the epsilon-greedy coin. It returns {@code true} (exploit) with probability
     * {@code 1 - epsilon + epsilon / numCandidates}.
     *
     * <p>NOTE(review): when this returns {@code false}, callers pick uniformly among ALL candidates,
     * including the greedy one. The effective greedy probability is therefore higher than the textbook
     * {@code 1 - epsilon + epsilon / n}. This is kept as-is to preserve existing behaviour.
     */
    private boolean shouldExploit(int numCandidates) {
        double threshold = (1.0d - QEpsilon) + (QEpsilon / (double) numCandidates);
        return random.nextDouble() <= threshold;
    }

    /**
     * Epsilon-greedy selection over all {@code numAction()} actions.
     *
     * @param state           current state vector
     * @param qValueProvider  optional external Q-value source; may be {@code null}
     * @return the selected action id
     */
    @Override // com.github.cschen1205.falcon.Falcon
    public int selectActionId(double[] state, QValueProvider qValueProvider) {
        return shouldExploit(numAction())
                ? getActionIdWithMaxQ(state, qValueProvider)
                : getRandomActionId();
    }

    /** Picks an action id uniformly from {@code [0, numAction())}. */
    private int getRandomActionId() {
        return random.nextInt(numAction());
    }

    /**
     * Looks up Q(state, actionId). The provider's value is used if it reports a valid one;
     * otherwise the network is searched.
     */
    private double queryQ(double[] state, int actionId, QValueProvider qValueProvider) {
        QValue provided = qValueProvider != null
                ? qValueProvider.queryQValue(state, actionId, true)
                : QValue.Invalid();
        return searchQ(state, actionId, provided);
    }

    /**
     * Returns the action in {@code [0, numAction())} with the highest Q-value. Ties go to the
     * lowest id.
     *
     * @return the greedy action id, or -1 if no action has a Q-value above negative infinity
     */
    private int getActionIdWithMaxQ(double[] state, QValueProvider qValueProvider) {
        int bestId = -1;
        double bestQ = Double.NEGATIVE_INFINITY;
        for (int id = 0; id < numAction(); id++) {
            double q = queryQ(state, id, qValueProvider);
            if (q > bestQ) {
                bestQ = q;
                bestId = id;
            }
        }
        return bestId;
    }

    /**
     * Returns the candidate with the highest Q-value. Ties go to the first one in the set's
     * iteration order.
     *
     * @return the greedy action id, or -1 if no candidate has a Q-value above negative infinity
     */
    private int getActionIdWithMaxQ(double[] state, Set<Integer> candidates, QValueProvider qValueProvider) {
        int bestId = -1;
        double bestQ = Double.NEGATIVE_INFINITY;
        for (int id : candidates) {
            double q = queryQ(state, id, qValueProvider);
            if (q > bestQ) {
                bestQ = q;
                bestId = id;
            }
        }
        return bestId;
    }

    /**
     * Epsilon-greedy selection restricted to {@code candidateActionIds}.
     *
     * @param state               current state vector
     * @param candidateActionIds  allowed action ids; {@code null} means "all actions"
     * @param qValueProvider      optional external Q-value source; may be {@code null}
     * @return the selected action id, or -1 when the candidate set is empty
     */
    @Override // com.github.cschen1205.falcon.Falcon
    public int selectActionId(double[] state, Set<Integer> candidateActionIds, QValueProvider qValueProvider) {
        if (candidateActionIds == null) {
            // Previously this threw NullPointerException, e.g. during SARSA updates issued via
            // the learnQ overload that passes no action set.
            return selectActionId(state, qValueProvider);
        }
        if (candidateActionIds.isEmpty()) {
            return -1;
        }
        return shouldExploit(candidateActionIds.size())
                ? getActionIdWithMaxQ(state, candidateActionIds, qValueProvider)
                : getRandomActionId(candidateActionIds);
    }

    /**
     * Performs a TD update for a single transition, considering all actions in the next state.
     *
     * @return the value returned by {@code learn(...)}
     */
    public int learnQ(double[] state, int actionId, double[] nextState, double reward, QValueProvider qValueProvider) {
        return learnQ(state, actionId, nextState, (Set<Integer>) null, reward, qValueProvider);
    }

    /**
     * Performs a TD update for a single transition given as an action id.
     *
     * @param actionId       index of the taken action; must be in {@code [0, numAction())}
     * @param nextActionIds  actions allowed in {@code nextState}; {@code null} means all actions
     * @return the value returned by {@code learn(...)}
     */
    public int learnQ(double[] state, int actionId, double[] nextState, Set<Integer> nextActionIds, double reward, QValueProvider qValueProvider) {
        return learnQ(state, oneHotAction(actionId), nextState, nextActionIds, reward, qValueProvider);
    }

    /** Builds a length-{@code numAction()} vector with 1.0 at {@code actionId} and 0.0 elsewhere. */
    private double[] oneHotAction(int actionId) {
        double[] action = new double[numAction()];
        action[actionId] = 1.0d;
        return action;
    }

    /** Returns a length-{@code numReward()} vector filled with 1.0, used as the reward probe in {@link #searchQ}. */
    public double[] dummyRewards() {
        double[] rewards = new double[numReward()];
        Arrays.fill(rewards, 1.0d);
        return rewards;
    }

    /**
     * Performs a TD update for a single transition given as an action vector. It computes the new
     * Q-value and passes it to {@code learn} as the reward vector.
     *
     * @return the value returned by {@code learn(...)}
     */
    public int learnQ(double[] state, double[] action, double[] nextState, Set<Integer> nextActionIds, double reward, QValueProvider qValueProvider) {
        return learn(state, action, updateQValue(state, action, nextState, nextActionIds, reward, qValueProvider));
    }

    /**
     * Computes the TD error {@code target - currentQ}. For {@link TDMethod#QLearn} the target is
     * {@code reward + gamma * max_a Q(nextState, a)}. For {@link TDMethod#Sarsa} it is
     * {@code reward + gamma * Q(nextState, a')}, where {@code a'} is drawn from the epsilon-greedy
     * policy. Any other method yields 0.
     *
     * @param currentQ       Q-value of the transition being updated
     * @param nextState      successor state
     * @param nextActionIds  actions allowed in {@code nextState}; {@code null} means all actions
     * @param reward         immediate reward
     */
    public double getTDError(double currentQ, double[] nextState, Set<Integer> nextActionIds, double reward, QValueProvider qValueProvider) {
        if (method == TDMethod.QLearn) {
            return (reward + (QGamma * searchMaxQ(nextState, nextActionIds, qValueProvider))) - currentQ;
        }
        if (method == TDMethod.Sarsa) {
            // NOTE(review): the next action is re-sampled here rather than taken from the caller.
            int nextActionId = selectActionId(nextState, nextActionIds, qValueProvider);
            double nextQ = nextActionId != -1
                    ? queryQ(nextState, nextActionId, qValueProvider)
                    : config.initialQ;
            return (reward + (QGamma * nextQ)) - currentQ;
        }
        return 0.0d;
    }

    /**
     * Computes the updated Q-value for (state, action) and returns it as the vector
     * {@code [q, 1 - q]}, with {@code q} clamped to [0, 1].
     *
     * <p>If the provider returns a valid Q-value, that value is used directly. Otherwise the
     * network's current estimate is moved by {@code QAlpha * tdError}. When
     * {@code config.isBounded} is set, that step is further scaled by {@code (1 - currentQ)}.
     *
     * <p>NOTE(review): writes indices 0 and 1, so this requires {@code numReward() >= 2}.
     * It presumably expects a complement-coded reward field; confirm.
     */
    protected double[] updateQValue(double[] state, double[] action, double[] nextState, Set<Integer> nextActionIds, double reward, QValueProvider qValueProvider) {
        double newQ = 0.0d;
        boolean fromProvider = false;
        if (qValueProvider != null) {
            QValue provided = qValueProvider.queryQValue(state, action, false);
            if (provided.isValid()) {
                newQ = provided.getValue();
                fromProvider = true;
            }
        }
        if (!fromProvider) {
            double currentQ = searchQ(state, action, QValue.Invalid());
            double delta = QAlpha * getTDError(currentQ, nextState, nextActionIds, reward, qValueProvider);
            if (config.isBounded) {
                delta *= 1.0d - currentQ;
            }
            newQ = currentQ + delta;
        }
        double q = clamp(newQ, 0.0d, 1.0d);
        double[] rewardVector = new double[numReward()];
        rewardVector[0] = q;
        rewardVector[1] = 1.0d - q;
        return rewardVector;
    }

    /** Returns the set {@code {0, ..., numAction() - 1}}. */
    private Set<Integer> allActionIds() {
        Set<Integer> ids = new HashSet<>();
        for (int id = 0; id < numAction(); id++) {
            ids.add(id);
        }
        return ids;
    }

    /**
     * Returns the maximum Q-value over the candidate actions in {@code state}.
     *
     * @param candidateActionIds  actions to consider; {@code null} means all actions
     * @return the maximum, or {@code config.initialQ} if no candidate has a Q-value above negative infinity
     */
    private double searchMaxQ(double[] state, Set<Integer> candidateActionIds, QValueProvider qValueProvider) {
        Set<Integer> ids = candidateActionIds != null ? candidateActionIds : allActionIds();
        boolean found = false;
        double bestQ = Double.NEGATIVE_INFINITY;
        for (int id : ids) {
            double q = queryQ(state, id, qValueProvider);
            // Plain '>' (not Math.max) so NaN estimates are skipped rather than propagated.
            if (q > bestQ) {
                bestQ = q;
                found = true;
            }
        }
        return found ? bestQ : config.initialQ;
    }

    /** Q-value lookup for an action id; see {@link #searchQ(double[], double[], QValue)}. */
    private double searchQ(double[] state, int actionId, QValue qValue) {
        return searchQ(state, oneHotAction(actionId), qValue);
    }

    /**
     * Returns {@code qValue} if it is valid. Otherwise it runs the network's code competition for
     * (state, action) and reads Q from the winning node's reward weights.
     *
     * @return the Q-value, or {@code config.initialQ} when no node wins the competition
     */
    private double searchQ(double[] state, double[] action, QValue qValue) {
        if (qValue.isValid()) {
            return qValue.getValue();
        }
        int winner = compete(computeChoiceValues(this.nodes, state, action, null, this.config));
        if (winner == -1) {
            return config.initialQ;
        }
        FalconNode node = this.nodes.get(winner);
        return readQ(node.fuzzyAND(dummyRewards(), node.weight_rewards));
    }

    /** Extracts the Q-value from a reward vector: its first component. */
    public static double readQ(double[] rewards) {
        return rewards[0];
    }
}
