package org.nd4j.parameterserver.distributed.messages.aggregations;

import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Aggregation message whose accumulated result is reduced by summation.
 *
 * <p>The payload passed to the constructor is stored and registered as a chunk via
 * {@code addToChunks}. When several chunks take part in the aggregation
 * ({@code aggregationWidth != 1}), {@link #getAccumulatedResult()} sums the array
 * produced by the superclass instead of returning it unchanged.
 *
 * <p>NOTE(review): the class name suggests the chunks are partial dot products, but
 * nothing in this file establishes that — confirm against the callers.
 *
 * @deprecated marked deprecated in the original source; no replacement is named in this file.
 */
@Deprecated
public class DotAggregation extends BaseAggregation {
    private static final Logger log = LoggerFactory.getLogger(DotAggregation.class);

    /**
     * No-arg constructor that leaves all state at its defaults.
     * NOTE(review): presumably present for deserialization frameworks — confirm.
     */
    protected DotAggregation() {
    }

    /**
     * Creates an aggregation carrying {@code iNDArray} as its payload and registers that
     * payload as a chunk.
     *
     * @param j        forwarded as the first argument of the {@code BaseAggregation}
     *                 constructor (presumably the task id — TODO confirm)
     * @param s        forwarded as the second argument of the {@code BaseAggregation}
     *                 constructor (presumably the aggregation width — TODO confirm)
     * @param s2       forwarded as the third argument of the {@code BaseAggregation}
     *                 constructor (presumably the shard index — TODO confirm)
     * @param iNDArray the payload of this message
     */
    public DotAggregation(long j, short s, short s2, INDArray iNDArray) {
        super(j, s, s2);
        this.payload = iNDArray;
        addToChunks(this.payload);
    }

    /**
     * Returns the accumulated result, reduced by summation when more than one chunk
     * participates.
     *
     * @return the superclass result unchanged when {@code aggregationWidth == 1};
     *         otherwise a scalar holding the sum of all elements when that result is a
     *         row vector, or the sum along dimension 1 for any other shape
     */
    @Override
    public INDArray getAccumulatedResult() {
        INDArray accumulated = super.getAccumulatedResult();

        // With a single participant the superclass result is returned untouched.
        if (this.aggregationWidth == 1) {
            return accumulated;
        }

        // A row vector collapses to one scalar value.
        if (accumulated.isRowVector()) {
            return Nd4j.scalar(accumulated.sumNumber().doubleValue());
        }

        // Any other shape is reduced along dimension 1.
        return accumulated.sum(1);
    }

    /**
     * Pins this message on the clipboard and, once the clipboard reports the aggregation
     * for this originator/task as ready, unpins it and hands it to the trainer.
     *
     * <p>If {@code chunks} has not been initialised yet, the chunk bookkeeping is rebuilt
     * from the payload first.
     */
    @Override
    public void processMessage() {
        // NOTE(review): chunks is presumably null on instances that did not go through
        // the 4-arg constructor (e.g. after deserialization) — confirm.
        if (this.chunks == null) {
            this.chunks = new TreeMap<>();
            this.chunksCounter = new AtomicInteger(1);
            addToChunks(this.payload);
        }

        this.clipboard.pin(this);

        if (this.clipboard.isReady(getOriginatorId(), getTaskId())) {
            this.trainer.aggregationFinished(this.clipboard.unpin(getOriginatorId(), this.taskId));
        }
    }
}
