package org.nd4j.parameterserver;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.aeron.ipc.NDArrayCallback;
import org.nd4j.aeron.ipc.NDArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/parameterserver/ParameterServerListener.class */
public class ParameterServerListener implements NDArrayCallback, NDArrayHolder {
    /** Accumulator that every incoming array (or partial slice) is added into in place. */
    private INDArray arr;
    /** Number of updates (full or partial) accumulated so far; used as the divisor in {@link #finish()}. */
    private AtomicInteger totalN = new AtomicInteger(0);
    private boolean master;
    /**
     * Optional shape hint. Not set by the constructor; while it is {@code null},
     * {@link #onNDArray(INDArray)} reshapes incoming arrays to {@code [1, length]} before adding.
     */
    private int[] shape;

    /**
     * Creates a listener whose accumulator is a new array of the given shape.
     *
     * <p>Note: this does NOT populate the {@code shape} field; call {@link #setShape(int[])}
     * if incoming arrays should be added without being flattened to a row vector.
     *
     * @param shape shape of the accumulator array passed to {@code Nd4j.create}
     */
    public ParameterServerListener(int[] shape) {
        this.arr = Nd4j.create(shape);
    }

    /**
     * Adds {@code update} into the tensor-along-dimension slice {@code index} of the accumulator
     * and counts it as one update.
     *
     * @param update     the partial array to add
     * @param index      index of the tensor along {@code dimensions}; must fit in an {@code int}
     * @param dimensions dimensions defining the tensor slices
     * @throws ArithmeticException if {@code index} does not fit in an {@code int}
     */
    public synchronized void onNDArrayPartial(INDArray update, long index, int... dimensions) {
        // toIntExact instead of a plain (int) cast: an overflowing index would otherwise
        // silently wrap and corrupt an unrelated slice.
        this.arr.tensorAlongDimension(Math.toIntExact(index), dimensions).addi(update);
        this.totalN.incrementAndGet();
    }

    /**
     * Adds a full array into the accumulator and counts it as one update.
     *
     * <p>While {@code shape} is {@code null} (the state after construction), the incoming array
     * is reshaped to {@code [1, update.length()]} first; otherwise it is added as-is.
     *
     * @param update the array to add
     */
    public synchronized void onNDArray(INDArray update) {
        INDArray toAdd = (this.shape == null) ? update.reshape(1, update.length()) : update;
        this.arr.addi(toAdd);
        this.totalN.incrementAndGet();
    }

    /**
     * Divides the accumulator in place by the number of updates received, turning the running
     * sum into a mean.
     *
     * <p>If no update has been received, the accumulator is left untouched; dividing by zero
     * would otherwise fill it with NaN.
     */
    public synchronized void finish() {
        int updates = this.totalN.get();
        if (updates == 0) {
            return;
        }
        this.arr.divi(updates);
    }

    /** @return the number of updates accumulated so far */
    public int totalUpdates() {
        return this.totalN.get();
    }

    /** @return the accumulator array itself (not a copy) */
    public synchronized INDArray get() {
        return this.arr;
    }

    /**
     * @param index      tensor index along {@code dimensions}
     * @param dimensions dimensions defining the tensor slices
     * @return a view of the requested tensor slice of the accumulator
     */
    public synchronized INDArray getTad(int index, int... dimensions) {
        return this.arr.tensorAlongDimension(index, dimensions);
    }

    public INDArray getArr() {
        return this.arr;
    }

    public AtomicInteger getTotalN() {
        return this.totalN;
    }

    public boolean isMaster() {
        return this.master;
    }

    public int[] getShape() {
        return this.shape;
    }

    public void setArr(INDArray arr) {
        this.arr = arr;
    }

    public void setTotalN(AtomicInteger totalN) {
        this.totalN = totalN;
    }

    public void setMaster(boolean master) {
        this.master = master;
    }

    public void setShape(int[] shape) {
        this.shape = shape;
    }

    /**
     * Field-wise equality over {@code arr}, {@code totalN}, {@code master} and {@code shape}.
     *
     * <p>{@code AtomicInteger} does not override {@code equals}, so {@code totalN} is compared
     * by identity: two listeners are equal only if they share the same counter instance.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ParameterServerListener)) {
            return false;
        }
        ParameterServerListener other = (ParameterServerListener) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray myArr = getArr();
        INDArray otherArr = other.getArr();
        if (myArr == null ? otherArr != null : !myArr.equals(otherArr)) {
            return false;
        }
        AtomicInteger myTotal = getTotalN();
        AtomicInteger otherTotal = other.getTotalN();
        if (myTotal == null ? otherTotal != null : !myTotal.equals(otherTotal)) {
            return false;
        }
        return isMaster() == other.isMaster() && Arrays.equals(getShape(), other.getShape());
    }

    /** Symmetry hook for {@link #equals(Object)} so subclasses can opt out of equality. */
    protected boolean canEqual(Object obj) {
        return obj instanceof ParameterServerListener;
    }

    /**
     * Hash over the same fields as {@link #equals(Object)}. Because those fields are mutable,
     * the hash changes as updates arrive; avoid using instances as hash keys while active.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        INDArray myArr = getArr();
        result = result * prime + (myArr == null ? 43 : myArr.hashCode());
        AtomicInteger myTotal = getTotalN();
        result = result * prime + (myTotal == null ? 43 : myTotal.hashCode());
        result = result * prime + (isMaster() ? 79 : 97);
        result = result * prime + Arrays.hashCode(getShape());
        return result;
    }

    @Override
    public String toString() {
        return "ParameterServerListener(arr=" + getArr() + ", totalN=" + getTotalN() + ", master=" + isMaster() + ", shape=" + Arrays.toString(getShape()) + ")";
    }
}
