package org.deeplearning4j.spark.parameterserver.networking.v2;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.IndexedTail;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.v2.transport.UpdatesHandler;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/parameterserver/networking/v2/UpdatesConsumer.class */
/**
 * {@link UpdatesHandler} that consumes encoded gradient-update messages.
 *
 * <p>It operates in one of two modes, chosen by which collaborators were supplied:
 * <ul>
 *   <li>If a {@link GradientsAccumulator} is set, every incoming message is forwarded unchanged into
 *       an {@link IndexedTail} buffer. The buffer is created lazily from {@code numWorkers} and
 *       {@code params.shape()} when one was not supplied.</li>
 *   <li>Otherwise each message is decoded into {@code updates}, using threshold or bitmap decoding
 *       as selected by a header value. The accumulated updates are applied to {@code params} through
 *       the {@link StepFunction}. This happens every {@value #FLUSH_INTERVAL} messages, or when
 *       {@link #flush()} is called explicitly.</li>
 * </ul>
 * While bypass mode is enabled, incoming messages are ignored.
 */
public class UpdatesConsumer implements UpdatesHandler {
    private static final Logger log = LoggerFactory.getLogger(UpdatesConsumer.class);

    /** Number of decoded messages between automatic {@link #flush()} calls in direct-apply mode. */
    private static final int FLUSH_INTERVAL = 32;

    /** Index in the encoded message's data buffer that holds the encoding-type header. */
    private static final long ENCODING_HEADER_INDEX = 3L;

    /** Header value identifying a threshold-encoded (sparse) message. */
    private static final int THRESHOLD_ENCODING = 0;

    /** Header value identifying a bitmap-encoded (dense) message. */
    private static final int BITMAP_ENCODING = 1;

    protected int numWorkers;
    protected transient INDArray params;
    protected transient INDArray updates;
    protected transient StepFunction stepFunction;
    protected transient GradientsAccumulator accumulator;
    protected final transient AtomicLong updatesCount = new AtomicLong(0);
    protected final transient AtomicBoolean hasSomething = new AtomicBoolean(false);
    protected final transient AtomicBoolean bypassMode = new AtomicBoolean(false);
    protected final transient AtomicLong denseCounter = new AtomicLong(0);
    protected final transient AtomicLong sparseCounter = new AtomicLong(0);

    /**
     * Buffer of raw messages used in accumulator mode. It may be created lazily. The field is
     * {@code volatile} because it is published through double-checked locking, and without it other
     * threads could observe a partially constructed instance.
     */
    protected transient volatile IndexedTail updatesBuffer;

    /** Builder for {@link UpdatesConsumer}; every field defaults to {@code null}/0 when not set. */
    public static class UpdatesConsumerBuilder {
        private int numWorkers;
        private INDArray params;
        private INDArray updates;
        private StepFunction stepFunction;
        private GradientsAccumulator accumulator;
        private IndexedTail updatesBuffer;

        UpdatesConsumerBuilder() {
        }

        public UpdatesConsumerBuilder numWorkers(int i) {
            this.numWorkers = i;
            return this;
        }

        public UpdatesConsumerBuilder params(INDArray iNDArray) {
            this.params = iNDArray;
            return this;
        }

        public UpdatesConsumerBuilder updates(INDArray iNDArray) {
            this.updates = iNDArray;
            return this;
        }

        public UpdatesConsumerBuilder stepFunction(StepFunction stepFunction) {
            this.stepFunction = stepFunction;
            return this;
        }

        public UpdatesConsumerBuilder accumulator(GradientsAccumulator gradientsAccumulator) {
            this.accumulator = gradientsAccumulator;
            return this;
        }

        public UpdatesConsumerBuilder updatesBuffer(IndexedTail indexedTail) {
            this.updatesBuffer = indexedTail;
            return this;
        }

        public UpdatesConsumer build() {
            return new UpdatesConsumer(this.numWorkers, this.params, this.updates, this.stepFunction, this.accumulator, this.updatesBuffer);
        }

        public String toString() {
            return "UpdatesConsumer.UpdatesConsumerBuilder(numWorkers=" + this.numWorkers + ", params=" + this.params + ", updates=" + this.updates + ", stepFunction=" + this.stepFunction + ", accumulator=" + this.accumulator + ", updatesBuffer=" + this.updatesBuffer + ")";
        }
    }

    /** Subscription callback; intentionally a no-op because this consumer does not request items. */
    public void onSubscribe(Subscription subscription) {
    }

    /**
     * Enables or disables bypass mode.
     *
     * @param z {@code true} to make {@link #onNext(INDArray)} ignore incoming messages
     */
    public void bypassMode(boolean z) {
        this.bypassMode.set(z);
    }

    /**
     * Reports whether bypass mode is enabled. The name is misspelled but is kept for backward
     * compatibility; prefer {@link #isBypassMode()}.
     *
     * @return {@code true} if incoming messages are currently ignored
     */
    public boolean isBypassMod() {
        return this.bypassMode.get();
    }

    /**
     * Reports whether bypass mode is enabled.
     *
     * @return {@code true} if incoming messages are currently ignored
     */
    public boolean isBypassMode() {
        return isBypassMod();
    }

    /**
     * Returns the accumulator-mode message buffer, creating it first if needed.
     *
     * @return the buffer, or {@code null} when no accumulator is configured and none was supplied
     */
    public IndexedTail getUpdatesQueue() {
        ensureUpdatesBuffer();
        return this.updatesBuffer;
    }

    /**
     * Creates {@link #updatesBuffer} on first use when an accumulator is configured. This uses
     * double-checked locking, which is correct only because the field is {@code volatile}.
     */
    private void ensureUpdatesBuffer() {
        if (this.updatesBuffer == null && this.accumulator != null) {
            synchronized (this) {
                if (this.updatesBuffer == null) {
                    this.updatesBuffer = new IndexedTail(this.numWorkers, true, this.params.shape());
                }
            }
        }
    }

    /**
     * Handles one encoded update message.
     *
     * <p>In accumulator mode the message is appended to the buffer. Otherwise it is decoded into
     * {@code updates}, and every {@value #FLUSH_INTERVAL} messages the result is applied via
     * {@link #flush()}.
     *
     * @param iNDArray encoded update message
     * @throws ND4JIllegalStateException if neither an accumulator nor params/stepFunction/updates are
     *     configured
     * @throws DL4JInvalidConfigException if the encoding header is not recognized
     * @throws RuntimeException wrapping any failure while enqueuing into the buffer
     */
    public void onNext(INDArray iNDArray) {
        // The buffer is initialized before the bypass check, so getUpdatesQueue() has a buffer even
        // while messages are being skipped.
        ensureUpdatesBuffer();
        if (this.bypassMode.get()) {
            return;
        }
        if (this.accumulator != null) {
            try {
                this.updatesBuffer.put(iNDArray);
            } catch (Exception e) {
                log.error("", e);
                throw new RuntimeException(e);
            }
            return;
        }
        if (this.params == null || this.stepFunction == null) {
            throw new ND4JIllegalStateException("Accumulator & StepFunction is null at the same time");
        }
        if (this.updates == null) {
            throw new ND4JIllegalStateException("Updates array is null: can't decode incoming message");
        }
        synchronized (this) {
            decodeInto(iNDArray);
            this.hasSomething.set(true);
            if (this.updatesCount.incrementAndGet() % FLUSH_INTERVAL == 0) {
                flush();
            }
        }
    }

    /**
     * Decodes one encoded message into {@link #updates}, choosing the decoder by the header value
     * stored at {@link #ENCODING_HEADER_INDEX}. The caller must hold the monitor on {@code this}.
     */
    private void decodeInto(INDArray encoded) {
        int encoding = encoded.data().getInt(ENCODING_HEADER_INDEX);
        switch (encoding) {
            case THRESHOLD_ENCODING:
                Nd4j.getExecutioner().thresholdDecode(encoded, this.updates);
                this.sparseCounter.incrementAndGet();
                break;
            case BITMAP_ENCODING:
                Nd4j.getExecutioner().bitmapDecode(encoded, this.updates);
                this.denseCounter.incrementAndGet();
                break;
            default:
                throw new DL4JInvalidConfigException("Unknown compression header received: " + encoding);
        }
    }

    /**
     * Applies any pending decoded updates to {@code params} through the step function, then zeroes
     * {@code updates}. It does nothing when there is nothing pending or when params, updates or the
     * step function is missing.
     */
    public void flush() {
        synchronized (this) {
            if (this.params == null || this.updates == null || this.stepFunction == null || !this.hasSomething.get()) {
                return;
            }
            this.stepFunction.step(this.params, this.updates);
            Nd4j.getExecutioner().commit();
            if (log.isDebugEnabled()) {
                long sparse = this.sparseCounter.get();
                long dense = this.denseCounter.get();
                // Floating-point division: long division would throw ArithmeticException while no
                // dense (bitmap) message has been received yet.
                log.debug("Applying updates. Current ratio: [{}]; Sparse: [{}]; Dense: [{}];", new Object[]{(double) sparse / dense, sparse, dense});
            }
            Nd4j.getMemoryManager().memset(this.updates);
            this.hasSomething.set(false);
        }
    }

    /**
     * Propagates a transport error by rethrowing it.
     *
     * @throws RuntimeException always, wrapping {@code th}
     */
    public void onError(Throwable th) {
        throw new RuntimeException(th);
    }

    /** Completion callback; intentionally a no-op. */
    public void onComplete() {
    }

    /**
     * Returns a copy of the current parameters, taken while holding this object's monitor so it does
     * not interleave with decoding or flushing.
     *
     * @return duplicate of {@code params} in the same ordering
     */
    public INDArray getParametersArray() {
        INDArray dup;
        synchronized (this) {
            dup = this.params.dup(this.params.ordering());
        }
        return dup;
    }

    /** @return a new builder with every field unset */
    public static UpdatesConsumerBuilder builder() {
        return new UpdatesConsumerBuilder();
    }

    /**
     * Creates a consumer.
     *
     * @param i number of workers, used when the accumulator-mode buffer is created lazily
     * @param iNDArray parameters to update in direct-apply mode
     * @param iNDArray2 scratch array that messages are decoded into
     * @param stepFunction function applying {@code updates} to {@code params}
     * @param gradientsAccumulator if non-null, selects accumulator mode
     * @param indexedTail optional pre-built buffer for accumulator mode
     */
    public UpdatesConsumer(int i, INDArray iNDArray, INDArray iNDArray2, StepFunction stepFunction, GradientsAccumulator gradientsAccumulator, IndexedTail indexedTail) {
        this.numWorkers = i;
        this.params = iNDArray;
        this.updates = iNDArray2;
        this.stepFunction = stepFunction;
        this.accumulator = gradientsAccumulator;
        this.updatesBuffer = indexedTail;
    }

    /** No-arg constructor, e.g. for serialization frameworks; all collaborators are left null. */
    public UpdatesConsumer() {
    }
}
