package org.deeplearning4j.parallelism.inference.observers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservable.class */
/**
 * An {@link InferenceObservable} that collects inputs from several requests so they can be run as one or more merged
 * batches, and then splits the batched outputs back out per request.
 *
 * <p>Flow visible in this class:
 * <ol>
 *   <li>Each request calls {@link #addInput(INDArray[], INDArray[])}. The request's index is stored in a
 *       {@link ThreadLocal}, so {@link #getOutput()} must be called on the same thread (or after
 *       {@link #setPosition(int)}).</li>
 *   <li>{@link #getInputBatches()} groups consecutive requests whose input shapes are identical and merges each group
 *       with {@link DataSetUtil#mergeFeatures(INDArray[][], INDArray[][])}.</li>
 *   <li>{@link #setOutputBatches(List)} slices each merged output back into per-request outputs, then notifies
 *       observers.</li>
 * </ol>
 *
 * <p>Locking: {@link #isLocked()} try-acquires the read lock and reports {@code true} while
 * {@link #getInputBatches()} holds the write lock or after it has run. {@link #addInput(INDArray[], INDArray[])}
 * releases a read hold previously taken by {@link #isLocked()}. NOTE(review): this pairing assumes the caller invokes
 * {@code isLocked()} and then {@code addInput(...)} on the same thread — confirm against callers.
 */
public class BatchedInferenceObservable extends BasicInferenceObservable implements InferenceObservable {
    private static final Logger log = LoggerFactory.getLogger(BatchedInferenceObservable.class);

    /** Input arrays of each request, in arrival order. */
    private final List<INDArray[]> inputs = new ArrayList<>();
    /** Input masks of each request (entries may be null), parallel to {@link #inputs}. */
    private final List<INDArray[]> inputMasks = new ArrayList<>();
    /** Per-request outputs filled by {@link #setOutputBatches(List)}; slot i belongs to request i. */
    private final List<INDArray[]> outputs = new ArrayList<>();
    /** Number of requests added so far; also the index handed to the next request. */
    private final AtomicInteger counter = new AtomicInteger(0);
    /** Request index associated with the calling thread. */
    private final ThreadLocal<Integer> position = new ThreadLocal<>();
    /** For each merged batch, the inclusive {@code {first, last}} request indices it covers. */
    private final List<int[]> outputBatchInputArrays = new ArrayList<>();
    /** Serializes {@link #addInput(INDArray[], INDArray[])} calls. */
    private final Object locker = new Object();
    private final ReentrantReadWriteLock realLocker = new ReentrantReadWriteLock();
    /** Set once {@link #getInputBatches()} has started; never cleared. */
    private final AtomicBoolean isLocked = new AtomicBoolean(false);
    /** Set when {@link #isLocked()} acquired a read hold that {@code addInput} is expected to release. */
    private final AtomicBoolean isReadLocked = new AtomicBoolean(false);

    public BatchedInferenceObservable() {
        super(new INDArray[0]);
    }

    /**
     * Registers one request's inputs and masks, and records the request's index for the calling thread.
     *
     * <p>If {@link #isLocked()} previously acquired a read hold on this thread, that hold is released here.
     *
     * @param input     the request's input arrays
     * @param inputMask the request's input masks; may be {@code null}
     */
    @Override
    public void addInput(INDArray[] input, INDArray[] inputMask) {
        synchronized (this.locker) {
            this.inputs.add(input);
            this.inputMasks.add(inputMask);
            this.position.set(this.counter.getAndIncrement());
            // isReadLocked is never cleared, so it can be stale. Only unlock when this thread really holds a read
            // lock; otherwise unlock() would throw IllegalMonitorStateException.
            if (this.isReadLocked.get() && this.realLocker.getReadHoldCount() > 0) {
                this.realLocker.readLock().unlock();
            }
        }
    }

    /**
     * Groups the registered requests into mergeable batches and merges each batch.
     *
     * <p>Consecutive requests whose input arrays have identical shapes are merged together. The covered request range
     * of each batch is recorded so that {@link #setOutputBatches(List)} can split the results. The write lock is held
     * for the whole operation and is always released, even if merging fails.
     *
     * @return one (features, masks) pair per batch; the masks element may be {@code null}
     */
    @Override
    public List<Pair<INDArray[], INDArray[]>> getInputBatches() {
        this.realLocker.writeLock().lock();
        try {
            this.isLocked.set(true);
            this.outputBatchInputArrays.clear();

            if (this.counter.get() <= 1) {
                // A single request needs no merging.
                this.outputBatchInputArrays.add(new int[] {0, 0});
                return Collections.singletonList(
                        new Pair<INDArray[], INDArray[]>(this.inputs.get(0), this.inputMasks.get(0)));
            }

            List<Pair<INDArray[], INDArray[]>> batches = new ArrayList<>();
            int start = 0;
            while (start < this.inputs.size()) {
                int end = lastBatchableIndex(start);
                batches.add(mergeRange(start, end));
                this.outputBatchInputArrays.add(new int[] {start, end});
                start = end + 1;
            }
            return batches;
        } finally {
            this.realLocker.writeLock().unlock();
        }
    }

    /** Returns the largest {@code end >= start} such that every request in {@code start..end} can batch with {@code start}. */
    private int lastBatchableIndex(int start) {
        int end = start;
        while (end + 1 < this.inputs.size() && canBatch(this.inputs.get(start), this.inputs.get(end + 1))) {
            end++;
        }
        return end;
    }

    /** Merges the inputs (and masks, if any request in the range has one) of requests {@code start..end} inclusive. */
    private Pair<INDArray[], INDArray[]> mergeRange(int start, int end) {
        int count = end - start + 1;
        INDArray[][] features = new INDArray[count][];
        // Stays null unless at least one request in the range supplied a mask.
        INDArray[][] masks = null;
        for (int k = 0; k < count; k++) {
            features[k] = this.inputs.get(start + k);
            INDArray[] mask = this.inputMasks.get(start + k);
            if (mask != null) {
                if (masks == null) {
                    // Rows for requests without a mask remain null.
                    masks = new INDArray[count][];
                }
                masks[k] = mask;
            }
        }
        return DataSetUtil.mergeFeatures(features, masks);
    }

    /**
     * Returns whether two requests can be merged: same number of input arrays, each with an identical shape.
     * Masks are not compared. NOTE(review): this presumes masks share shape with their inputs — confirm.
     */
    private static boolean canBatch(INDArray[] first, INDArray[] candidate) {
        if (first.length != candidate.length) {
            return false;
        }
        for (int i = 0; i < first.length; i++) {
            if (!Arrays.equals(first[i].shape(), candidate[i].shape())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Splits batched outputs back into per-request outputs and notifies observers.
     *
     * @param outputBatches one entry per batch returned by {@link #getInputBatches()}, in the same order; each entry
     *                      holds that batch's output arrays
     */
    @Override
    public void setOutputBatches(List<INDArray[]> outputBatches) {
        for (int b = 0; b < outputBatches.size(); b++) {
            INDArray[] batchOutputs = outputBatches.get(b);
            int[] range = this.outputBatchInputArrays.get(b);
            int count = range[1] - range[0] + 1;

            // Slot in `outputs` of the first request covered by this batch. It is derived from the list size, so it
            // stays correct even when a batch has zero output arrays.
            int base = this.outputs.size();
            for (int k = 0; k < count; k++) {
                this.outputs.add(new INDArray[batchOutputs.length]);
            }
            for (int o = 0; o < batchOutputs.length; o++) {
                INDArray[] parts = splitExamples(batchOutputs[o], range[0], range[1]);
                for (int k = 0; k < count; k++) {
                    this.outputs.get(base + k)[o] = parts[k];
                }
            }
        }
        setChanged();
        notifyObservers();
    }

    /**
     * Slices a batched array along dimension 0 into one view per request in {@code first..last}.
     *
     * <p>Each request's row count is taken from {@code size(0)} of that request's first input array.
     *
     * @param batched the merged array to split
     * @param first   index of the first request in the batch (inclusive)
     * @param last    index of the last request in the batch (inclusive)
     * @return one view per request; for a single-request batch, the array itself
     */
    private INDArray[] splitExamples(INDArray batched, int first, int last) {
        int count = last - first + 1;
        if (count == 1) {
            return new INDArray[] {batched};
        }
        INDArray[] parts = new INDArray[count];
        INDArrayIndex[] indices = new INDArrayIndex[batched.rank()];
        for (int d = 1; d < indices.length; d++) {
            indices[d] = NDArrayIndex.all();
        }
        // long offset: avoids truncating to int on very large batches.
        long offset = 0;
        for (int k = 0; k < count; k++) {
            long rows = this.inputs.get(first + k)[0].size(0);
            indices[0] = NDArrayIndex.interval(offset, offset + rows);
            parts[k] = batched.get(indices);
            offset += rows;
        }
        return parts;
    }

    /** Returns the live per-request output list (not a copy). */
    protected List<INDArray[]> getOutputs() {
        return this.outputs;
    }

    /** Overwrites the request counter. */
    protected void setCounter(int i) {
        this.counter.set(i);
    }

    /** Sets the request index used by {@link #getOutput()} for the calling thread. */
    public void setPosition(int i) {
        this.position.set(i);
    }

    /** Returns the number of requests registered so far (unless overwritten via {@link #setCounter(int)}). */
    public int getCounter() {
        return this.counter.get();
    }

    /**
     * Reports whether this observable should no longer accept inputs.
     *
     * <p>Returns {@code true} if the read lock cannot be acquired (a writer holds it), or if {@link #getInputBatches()}
     * has already run. When it returns {@code false`, the calling thread is left holding a read lock. That lock is
     * released by the next {@link #addInput(INDArray[], INDArray[])} on the same thread.
     *
     * @return {@code true} if locked; {@code false} if the caller now holds a read lock
     */
    public boolean isLocked() {
        if (!this.realLocker.readLock().tryLock()) {
            return true;
        }
        if (this.isLocked.get()) {
            // Batches were already assembled. Give back the hold we just took instead of leaking it, because
            // isReadLocked is not set on this path.
            this.realLocker.readLock().unlock();
            return true;
        }
        this.isReadLocked.set(true);
        return false;
    }

    /**
     * Returns the outputs for the calling thread's request.
     *
     * <p>The request index comes from the thread-local position set by {@code addInput} or {@code setPosition}. If
     * neither was called on this thread, the position is null and unboxing it throws {@link NullPointerException}.
     */
    @Override
    public INDArray[] getOutput() {
        checkOutputException();
        return this.outputs.get(this.position.get());
    }
}
