package org.nd4j.parameterserver.distributed.v2.chunks.impl;

import java.io.ByteArrayInputStream;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import lombok.NonNull;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.util.SerializationUtils;
import org.nd4j.parameterserver.distributed.v2.chunks.ChunksTracker;
import org.nd4j.parameterserver.distributed.v2.chunks.VoidChunk;
import org.nd4j.parameterserver.distributed.v2.messages.VoidMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/parameterserver/distributed/v2/chunks/impl/InmemoryChunksTracker.class */
/**
 * {@link ChunksTracker} that reassembles a message split into {@link VoidChunk}s inside a single
 * in-memory byte array.
 *
 * <p>The buffer is sized from {@link VoidChunk#getTotalSize()} of the chunk passed to the
 * constructor. That size is therefore limited to {@code Integer.MAX_VALUE} bytes. Each chunk id in
 * {@code [0, numberOfChunks)} is tracked with a flag. A chunk's payload is copied to offset
 * {@code chunkId * splitSize}. Once every flag is set, {@link #getMessage()} deserializes the
 * buffer back into the original message.
 *
 * <p>{@link #append(VoidChunk)} is {@code synchronized}. {@link #isComplete()} and
 * {@link #getMessage()} are not; they read the per-chunk flags, which are set only after the
 * payload has been copied. NOTE(review): visibility of buffer writes to unsynchronized readers
 * relies on the flag type having volatile semantics — confirm for
 * {@code org.nd4j.linalg.primitives.AtomicBoolean}.
 *
 * @param <T> type of the reassembled message
 */
public class InmemoryChunksTracker<T extends VoidMessage> implements ChunksTracker<T> {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) InmemoryChunksTracker.class);

    private final String originId;
    private final int numChunks;
    /** Per-chunk "received" flags, keyed by chunk id. */
    private final Map<Integer, AtomicBoolean> map = new ConcurrentHashMap<>();
    /** Destination buffer for the serialized message. */
    private final byte[] buffer;
    /** Total size of the serialized message in bytes. */
    private final long size;

    /**
     * Creates a tracker sized for the message that {@code voidChunk} belongs to.
     *
     * @param voidChunk any chunk of the message; its origin id, chunk count and total size are used
     * @throws ND4JIllegalStateException if the total size is negative or exceeds
     *     {@code Integer.MAX_VALUE}
     */
    public InmemoryChunksTracker(VoidChunk voidChunk) {
        this.originId = voidChunk.getOriginalId();
        this.numChunks = voidChunk.getNumberOfChunks();

        long totalSize = voidChunk.getTotalSize();
        if (totalSize > Integer.MAX_VALUE) {
            throw new ND4JIllegalStateException("Total message size > Integer.MAX_VALUE");
        }
        if (totalSize < 0) {
            // Previously surfaced as a wrapped NegativeArraySizeException; fail with a clear message instead.
            throw new ND4JIllegalStateException("Total message size must be non-negative, got " + totalSize);
        }
        this.size = totalSize;
        this.buffer = new byte[(int) totalSize];

        for (int i = 0; i < this.numChunks; i++) {
            this.map.put(i, new AtomicBoolean(false));
        }
    }

    /**
     * Returns the total size of the serialized message in bytes.
     *
     * @return total message size
     */
    @Override // org.nd4j.parameterserver.distributed.v2.chunks.ChunksTracker
    public long size() {
        return this.size;
    }

    /**
     * Checks whether every chunk of the message has been received.
     *
     * @return {@code true} once all chunk ids have been appended
     */
    @Override // org.nd4j.parameterserver.distributed.v2.chunks.ChunksTracker
    public boolean isComplete() {
        for (AtomicBoolean received : this.map.values()) {
            if (!received.get()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Copies the chunk's payload into the buffer at offset {@code chunkId * splitSize}.
     *
     * <p>Duplicate chunks are ignored: if the chunk id was already received, the buffer is left
     * untouched.
     *
     * @param voidChunk the chunk to store
     * @return {@code true} if the message is complete after this call
     * @throws NullPointerException if {@code voidChunk} is null
     * @throws ND4JIllegalStateException if the chunk id is unknown to this tracker, or the payload
     *     does not fit inside the message buffer
     */
    @Override // org.nd4j.parameterserver.distributed.v2.chunks.ChunksTracker
    public synchronized boolean append(@NonNull VoidChunk voidChunk) {
        if (voidChunk == null) {
            throw new NullPointerException("chunk is marked @NonNull but is null");
        }

        int chunkId = voidChunk.getChunkId();
        AtomicBoolean received = this.map.get(chunkId);
        if (received == null) {
            throw new ND4JIllegalStateException("Chunk id " + chunkId + " is out of range [0, "
                    + this.numChunks + ") for message " + this.originId);
        }
        if (received.get()) {
            // Duplicate delivery: keep the data we already have.
            return isComplete();
        }

        byte[] payload = voidChunk.getPayload();
        // Compute in long so a bad chunkId/splitSize pair cannot silently wrap around int.
        long offset = (long) chunkId * voidChunk.getSplitSize();
        if (offset < 0 || offset + payload.length > this.buffer.length) {
            throw new ND4JIllegalStateException("Chunk " + chunkId + " (offset " + offset + ", length "
                    + payload.length + ") does not fit into message buffer of size "
                    + this.buffer.length + " for message " + this.originId);
        }
        System.arraycopy(payload, 0, this.buffer, (int) offset, payload.length);

        // Mark as received only after the payload is fully copied.
        received.set(true);
        return isComplete();
    }

    /**
     * Deserializes the reassembled buffer into the original message.
     *
     * @return the reassembled message
     * @throws ND4JIllegalStateException if not all chunks have been received yet
     * @throws RuntimeException wrapping any failure raised during deserialization
     */
    @Override // org.nd4j.parameterserver.distributed.v2.chunks.ChunksTracker
    public T getMessage() {
        if (!isComplete()) {
            throw new ND4JIllegalStateException("Message isn't ready for concatenation");
        }
        try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(this.buffer)) {
            @SuppressWarnings("unchecked") // the buffer holds the serialized form of a T
            T message = (T) SerializationUtils.deserialize(byteArrayInputStream);
            return message;
        } catch (Exception e) {
            // Pass the exception as the trailing argument so SLF4J logs the stack trace.
            log.error("Failed to deserialize message for originId {}", this.originId, e);
            throw new RuntimeException(e);
        }
    }

    /**
     * No-op: this implementation holds only heap memory owned by the tracker itself.
     */
    @Override // org.nd4j.parameterserver.distributed.v2.chunks.ChunksTracker
    public void release() {
    }

    /**
     * Returns the id of the original message this tracker reassembles.
     *
     * @return origin message id, taken from the chunk passed to the constructor
     */
    @Override // org.nd4j.parameterserver.distributed.v2.chunks.ChunksTracker
    public String getOriginId() {
        return this.originId;
    }
}
