package org.nd4j.autodiff.listeners;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.apache.camel.util.URISupport;
import org.nd4j.base.Preconditions;

/* loaded from: input_file:org/nd4j/autodiff/listeners/Loss.class */
/**
 * Holder for a set of named loss values: a list of loss names paired, index by index,
 * with an array of loss values of the same length.
 *
 * <p>The constructor stores the supplied list and array by reference, and the accessors
 * return those same objects; no defensive copies are made there. Use {@link #copy()} to
 * obtain an instance whose names and values are independent of this one.
 */
public class Loss {
    private final List<String> lossNames;
    private final double[] losses;

    /**
     * Creates a loss holder from parallel names and values.
     *
     * @param list the loss names; stored by reference
     * @param dArr the loss values, one per name; stored by reference
     * @throws NullPointerException  if either argument is null
     * @throws IllegalStateException if the number of names differs from the number of values
     */
    public Loss(@NonNull List<String> list, @NonNull double[] dArr) {
        if (list == null) {
            throw new NullPointerException("lossNames is marked @NonNull but is null");
        }
        if (dArr == null) {
            throw new NullPointerException("losses is marked @NonNull but is null");
        }
        Preconditions.checkState(list.size() == dArr.length, "Expected equal number of loss names and loss values");
        this.lossNames = list;
        this.losses = dArr;
    }

    /**
     * @return the number of loss names held by this instance
     */
    public int numLosses() {
        return this.lossNames.size();
    }

    /**
     * @return the internal list of loss names (not a copy)
     */
    public List<String> lossNames() {
        return this.lossNames;
    }

    /**
     * @return the internal array of loss values (not a copy)
     */
    public double[] lossValues() {
        return this.losses;
    }

    /**
     * Looks up a single loss value by name.
     *
     * @param str the loss name to look up
     * @return the value stored at the first index whose name equals {@code str}
     * @throws NullPointerException  if {@code str} is null
     * @throws IllegalStateException if no loss with that name exists
     */
    public double getLoss(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("lossName is marked @NonNull but is null");
        }
        int index = this.lossNames.indexOf(str);
        Preconditions.checkState(index >= 0, "No loss with name \"%s\" exists. All loss names: %s", str, this.lossNames);
        return this.losses[index];
    }

    /**
     * @return the sum of all loss values; {@code 0.0} when there are none
     */
    public double totalLoss() {
        double total = 0.0d;
        for (double value : this.losses) {
            total += value;
        }
        return total;
    }

    /**
     * Creates an independent copy of this loss.
     *
     * @return a new instance with its own names list and its own values array, so that
     *         later changes to either instance's data do not affect the other
     */
    public Loss copy() {
        return new Loss(new ArrayList<>(this.lossNames), this.losses.clone());
    }

    /**
     * Sums a list of losses element-wise.
     *
     * @param list the losses to add together; all must have equal loss names
     * @return a new loss holding the per-name totals, or a loss with no names and no values
     *         when {@code list} is empty
     * @throws IllegalStateException if an entry's value count or loss names differ from the first entry's
     */
    public static Loss sum(List<Loss> list) {
        if (list.isEmpty()) {
            return new Loss(Collections.emptyList(), new double[0]);
        }
        Loss first = list.get(0);
        double[] totals = new double[first.losses.length];
        // Copy the first entry's names so the result does not share its list.
        List<String> names = new ArrayList<>(first.lossNames);
        for (int i = 0; i < list.size(); i++) {
            Loss loss = list.get(i);
            Preconditions.checkState(loss.losses.length == totals.length, "Loss %s has %s losses, the others before it had %s.", i, loss.losses.length, totals.length);
            Preconditions.checkState(loss.lossNames.equals(names), "Loss %s has different loss names from the others before it.  Expected %s, got %s.", Integer.valueOf(i), names, loss.lossNames);
            for (int j = 0; j < totals.length; j++) {
                totals[j] += loss.losses[j];
            }
        }
        return new Loss(names, totals);
    }

    /**
     * Averages a list of losses element-wise.
     *
     * @param list the losses to average; all must have equal loss names
     * @return a new loss holding the per-name means, or a loss with no names and no values
     *         when {@code list} is empty
     * @throws IllegalStateException if the entries are not compatible (see {@link #sum(List)})
     */
    public static Loss average(List<Loss> list) {
        Loss result = sum(list);
        // sum() returns a freshly allocated array, so dividing in place touches no input.
        for (int i = 0; i < result.losses.length; i++) {
            result.losses[i] /= list.size();
        }
        return result;
    }

    /**
     * Adds two losses element-wise.
     *
     * @param loss  the first operand
     * @param loss2 the second operand
     * @return a new loss with values {@code loss[i] + loss2[i]}, using the first operand's names list
     * @throws IllegalStateException if the two operands have different loss names
     */
    public static Loss add(Loss loss, Loss loss2) {
        Preconditions.checkState(loss.lossNames.equals(loss2.lossNames), "Loss names differ.  First loss has names %s, second has names %s.", loss.lossNames, loss2.lossNames);
        double[] result = new double[loss.losses.length];
        for (int i = 0; i < result.length; i++) {
            result[i] = loss.losses[i] + loss2.losses[i];
        }
        return new Loss(loss.lossNames, result);
    }

    /**
     * Subtracts one loss from another element-wise.
     *
     * @param loss  the minuend
     * @param loss2 the subtrahend
     * @return a new loss with values {@code loss[i] - loss2[i]}, using the first operand's names list
     * @throws IllegalStateException if the two operands have different loss names
     */
    public static Loss sub(Loss loss, Loss loss2) {
        Preconditions.checkState(loss.lossNames.equals(loss2.lossNames), "Loss names differ.  First loss has names %s, second has names %s.", loss.lossNames, loss2.lossNames);
        double[] result = new double[loss.losses.length];
        for (int i = 0; i < result.length; i++) {
            result[i] = loss.losses[i] - loss2.losses[i];
        }
        return new Loss(loss.lossNames, result);
    }

    /**
     * Divides every value of a loss by a number.
     *
     * @param loss   the loss to divide
     * @param number the divisor; its {@code doubleValue()} is used
     * @return a new loss with values {@code loss[i] / number}, using the operand's names list
     */
    public static Loss div(Loss loss, Number number) {
        double[] result = new double[loss.losses.length];
        for (int i = 0; i < result.length; i++) {
            result[i] = loss.losses[i] / number.doubleValue();
        }
        return new Loss(loss.lossNames, result);
    }

    /**
     * Instance form of {@link #add(Loss, Loss)} with this loss as the first operand.
     *
     * @param loss the loss to add to this one
     * @return a new loss holding the element-wise sum
     */
    public Loss add(Loss loss) {
        return add(this, loss);
    }

    /**
     * Instance form of {@link #sub(Loss, Loss)} with this loss as the first operand.
     *
     * @param loss the loss to subtract from this one
     * @return a new loss holding the element-wise difference
     */
    public Loss sub(Loss loss) {
        return sub(this, loss);
    }

    /**
     * Alias of {@link #add(Loss)}.
     *
     * @param loss the loss to add to this one
     * @return a new loss holding the element-wise sum
     */
    public Loss plus(Loss loss) {
        return add(this, loss);
    }

    /**
     * Alias of {@link #sub(Loss)}.
     *
     * @param loss the loss to subtract from this one
     * @return a new loss holding the element-wise difference
     */
    public Loss minus(Loss loss) {
        return sub(this, loss);
    }

    /**
     * Instance form of {@link #div(Loss, Number)} with this loss as the dividend.
     *
     * @param number the divisor
     * @return a new loss holding the element-wise quotient
     */
    public Loss div(Number number) {
        return div(this, number);
    }

    /**
     * @return the internal list of loss names (not a copy)
     */
    public List<String> getLossNames() {
        return this.lossNames;
    }

    /**
     * @return the internal array of loss values (not a copy)
     */
    public double[] getLosses() {
        return this.losses;
    }

    /**
     * Two losses are equal when their loss names are equal and their value arrays have
     * equal contents (compared with {@link Arrays#equals(double[], double[])}).
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Loss)) {
            return false;
        }
        Loss other = (Loss) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        List<String> thisNames = getLossNames();
        List<String> otherNames = other.getLossNames();
        boolean sameNames = (thisNames == null) ? (otherNames == null) : thisNames.equals(otherNames);
        return sameNames && Arrays.equals(getLosses(), other.getLosses());
    }

    /**
     * Hook used by {@link #equals(Object)}: the other object is asked whether it accepts
     * being compared with this instance.
     *
     * @param obj the object on the other side of the comparison
     * @return {@code true} if {@code obj} is a {@code Loss}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof Loss;
    }

    /**
     * Hash code derived from the loss names and the contents of the value array.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        List<String> names = getLossNames();
        int result = 1;
        result = (result * prime) + (names == null ? 43 : names.hashCode());
        result = (result * prime) + Arrays.hashCode(getLosses());
        return result;
    }

    /**
     * @return a string of the form {@code Loss(lossNames=[...], losses=[...])}
     */
    @Override
    public String toString() {
        return "Loss(lossNames=" + getLossNames() + ", losses=" + Arrays.toString(getLosses()) + ")";
    }
}
