package org.nd4j.parameterserver.distributed.messages.aggregations;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.agrona.concurrent.UnsafeBuffer;
import org.apache.commons.lang3.SerializationUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.messages.BaseVoidMessage;
import org.nd4j.parameterserver.distributed.messages.VoidAggregation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/parameterserver/distributed/messages/aggregations/BaseAggregation.class */
/**
 * Base implementation of {@link VoidAggregation}.
 *
 * <p>Each instance carries one shard's {@link #payload}. It also keeps a per-shard map of payloads
 * ({@link #chunks}) that other aggregations of the same type are merged into via
 * {@link #accumulateAggregation(VoidAggregation)}. {@link #getMissingChunks()} reports how many
 * shards (out of {@link #aggregationWidth}) have not been seen yet.
 *
 * <p>{@link #chunks} and {@link #chunksCounter} are {@code transient}, so they are not part of the
 * form produced by {@link #asBytes()}. They are rebuilt after Java deserialization so that a
 * received instance can still accumulate further chunks.
 */
public abstract class BaseAggregation extends BaseVoidMessage implements VoidAggregation, Serializable {
    private static final Logger log = LoggerFactory.getLogger(BaseAggregation.class);

    /** Aggregation type code; -1 until a subclass assigns one. */
    protected short aggregationType;
    /** Total number of shard chunks expected for this aggregation. */
    protected short aggregationWidth;
    protected int numberOfElements;
    /** Index of the shard this instance's {@link #payload} belongs to. */
    protected short shardIndex;
    /** This shard's contribution. */
    protected INDArray payload;

    /** Number of distinct shards seen so far; starts at 1 to count this instance's own shard. */
    protected transient AtomicInteger chunksCounter;
    /** Collected payloads keyed by shard index. */
    protected transient Map<Short, INDArray> chunks;

    /** Creates an empty aggregation with no type assigned (type -1). Intended for subclasses. */
    public BaseAggregation() {
        this.aggregationType = (short) -1;
        resetTransientState();
    }

    /**
     * Creates an aggregation for the given task and shard. Intended for subclasses.
     *
     * @param j task id
     * @param s aggregation width (total number of shards expected)
     * @param s2 index of this instance's shard
     */
    public BaseAggregation(long j, short s, short s2) {
        this();
        this.aggregationWidth = s;
        this.taskId = j;
        this.shardIndex = s2;
    }

    /**
     * (Re)initialises the transient accumulation state. The counter starts at 1 (own shard), and
     * the own payload, if already present, is registered under {@link #shardIndex}. This matches
     * the mapping that {@link #setShardIndex(short)} maintains.
     */
    private void resetTransientState() {
        this.chunksCounter = new AtomicInteger(1);
        this.chunks = new ConcurrentHashMap<>();
        // ConcurrentHashMap rejects null values; payload is null right after the no-arg constructor.
        if (this.payload != null) {
            this.chunks.put(this.shardIndex, this.payload);
        }
    }

    /**
     * Java deserialization does not run this class's constructor, so the transient fields would
     * otherwise be null and {@link #accumulateAggregation(VoidAggregation)} would throw an NPE.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        resetTransientState();
    }

    /**
     * Moves this instance's own payload to a new shard slot in {@link #chunks}.
     *
     * @param s the new shard index; a no-op if it equals the current one
     */
    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage
    public void setShardIndex(short s) {
        if (s == this.shardIndex) {
            return;
        }
        this.chunks.remove(this.shardIndex);
        // ConcurrentHashMap rejects null values, so only move the local chunk if there is one.
        if (this.payload != null) {
            this.chunks.put(s, this.payload);
        }
        this.shardIndex = s;
    }

    /**
     * Stores the given array as the chunk for this instance's current shard index.
     *
     * @param iNDArray chunk to store; must be non-null (ConcurrentHashMap rejects nulls)
     */
    public void addToChunks(INDArray iNDArray) {
        this.chunks.put(this.shardIndex, iNDArray);
    }

    /**
     * Merges another aggregation's payload into this one's chunk map.
     *
     * <p>Aggregations from the same shard as this instance are ignored. The first chunk received
     * from a shard increments {@link #chunksCounter}; a later chunk from the same shard replaces
     * the stored payload without incrementing it again.
     *
     * @param voidAggregation aggregation to merge; must have the same aggregation type
     * @throws NullPointerException if {@code voidAggregation} is null
     * @throws ND4JIllegalStateException if the aggregation types differ
     */
    @Override // org.nd4j.parameterserver.distributed.messages.VoidAggregation
    public void accumulateAggregation(@NonNull VoidAggregation voidAggregation) {
        if (voidAggregation == null) {
            throw new NullPointerException("aggregation");
        }
        if (voidAggregation.getAggregationType() != getAggregationType()) {
            throw new ND4JIllegalStateException("Trying to aggregate different aggregations!");
        }
        if (getShardIndex() == voidAggregation.getShardIndex()) {
            return;
        }
        // Using put's return value makes "first chunk from this shard" detection atomic; a
        // separate get-then-put could double-count under concurrent delivery.
        if (this.chunks.put(voidAggregation.getShardIndex(), voidAggregation.getPayload()) == null) {
            this.chunksCounter.incrementAndGet();
        }
    }

    /**
     * Returns the accumulated result.
     *
     * @return the chunk stored under shard 0 when the width is 1; otherwise all chunks
     *     horizontally stacked in ascending shard-index order
     */
    @Override // org.nd4j.parameterserver.distributed.messages.VoidAggregation
    public INDArray getAccumulatedResult() {
        if (this.aggregationWidth == 1) {
            return this.chunks.get((short) 0);
        }
        // ConcurrentHashMap iteration order is unspecified; hstack is order-sensitive, so sort by
        // shard index explicitly.
        return Nd4j.hstack(new TreeMap<>(this.chunks).values());
    }

    /** @return {@code aggregationWidth} minus the number of distinct shards seen so far */
    @Override // org.nd4j.parameterserver.distributed.messages.VoidAggregation
    public int getMissingChunks() {
        return this.aggregationWidth - this.chunksCounter.get();
    }

    /** @return message-type code {@code 21} */
    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage, org.nd4j.parameterserver.distributed.messages.VoidMessage
    public int getMessageType() {
        return 21;
    }

    /** @return this object in Java-serialized form (transient chunk state excluded) */
    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage, org.nd4j.parameterserver.distributed.messages.VoidMessage
    public byte[] asBytes() {
        return SerializationUtils.serialize(this);
    }

    /** @return the bytes of {@link #asBytes()} wrapped in an {@link UnsafeBuffer} */
    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage, org.nd4j.parameterserver.distributed.messages.VoidMessage
    public UnsafeBuffer asUnsafeBuffer() {
        return new UnsafeBuffer(asBytes());
    }

    /** @return always -1 (NOTE(review): presumably "no specific target" — confirm) */
    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage, org.nd4j.parameterserver.distributed.messages.VoidMessage
    public short getTargetId() {
        return (short) -1;
    }

    @Override // org.nd4j.parameterserver.distributed.messages.VoidAggregation
    public short getAggregationType() {
        return this.aggregationType;
    }

    public void setAggregationType(short s) {
        this.aggregationType = s;
    }

    public short getAggregationWidth() {
        return this.aggregationWidth;
    }

    public void setAggregationWidth(short s) {
        this.aggregationWidth = s;
    }

    public int getNumberOfElements() {
        return this.numberOfElements;
    }

    public void setNumberOfElements(int i) {
        this.numberOfElements = i;
    }

    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage, org.nd4j.parameterserver.distributed.messages.VoidAggregation
    public short getShardIndex() {
        return this.shardIndex;
    }

    @Override // org.nd4j.parameterserver.distributed.messages.VoidAggregation
    public INDArray getPayload() {
        return this.payload;
    }

    public void setPayload(INDArray iNDArray) {
        this.payload = iNDArray;
    }

    /** @return the live (mutable) shard counter, not a copy */
    public AtomicInteger getChunksCounter() {
        return this.chunksCounter;
    }

    /** @return the live (mutable) chunk map, not a copy */
    public Map<Short, INDArray> getChunks() {
        return this.chunks;
    }
}
