package org.datavec.python;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;

/**
 * Bridges a NumPy-style strided memory region and an ND4J {@link INDArray}.
 *
 * <p>An instance is described by a raw native {@code address}, a {@code shape}, per-dimension
 * {@code strides} expressed in <em>bytes</em> (they are divided by / multiplied with the element
 * size when converting to / from ND4J element strides), and a {@link DataType}. When constructed
 * from those values the memory is wrapped as an INDArray without copying unless {@code copy} is
 * requested; when constructed from an INDArray, address/shape/byte strides are derived from it.
 *
 * <p>Wrapped arrays are stored in a process-wide cache keyed by address, element count, dtype and
 * byte strides; this class never removes entries from it. NOTE(review): the cache is a plain
 * {@link HashMap} with no synchronization, so concurrent use from several threads is unsafe —
 * confirm callers are single-threaded.
 */
public class NumpyArray {
    /** Native operations handle used to turn a raw address into a pointer. */
    private static final NativeOps nativeOps;

    /**
     * Wrapped arrays keyed by {@code cacheKey(...)}. Never evicted by this class; presumably this
     * also keeps Java-allocated buffers alive while Python still references their address — TODO
     * confirm before adding eviction.
     */
    private static final Map<String, INDArray> arrayCache;

    private long address;
    private long[] shape;
    /** Per-dimension strides in bytes (NumPy convention), not in elements. */
    private long[] strides;
    private DataType dtype;
    private INDArray nd4jArray;

    /** Fluent builder for {@link NumpyArray}; {@code copy} defaults to {@code false}. */
    public static class NumpyArrayBuilder {
        private long address;
        private long[] shape;
        private long[] strides;
        private DataType dtype;
        private boolean copy;

        NumpyArrayBuilder() {
        }

        /** Sets the native address of the first element. */
        public NumpyArrayBuilder address(long address) {
            this.address = address;
            return this;
        }

        /** Sets the array shape. */
        public NumpyArrayBuilder shape(long[] shape) {
            this.shape = shape;
            return this;
        }

        /** Sets the per-dimension strides, in bytes. */
        public NumpyArrayBuilder strides(long[] strides) {
            this.strides = strides;
            return this;
        }

        /** Sets the element data type. */
        public NumpyArrayBuilder dtype(DataType dtype) {
            this.dtype = dtype;
            return this;
        }

        /** Whether the wrapped memory should be duplicated into ND4J-owned memory. */
        public NumpyArrayBuilder copy(boolean copy) {
            this.copy = copy;
            return this;
        }

        /** Creates the {@link NumpyArray} from the configured values. */
        public NumpyArray build() {
            return new NumpyArray(this.address, this.shape, this.strides, this.dtype, this.copy);
        }

        @Override
        public String toString() {
            return "NumpyArray.NumpyArrayBuilder(address=" + this.address + ", shape=" + Arrays.toString(this.shape) + ", strides=" + Arrays.toString(this.strides) + ", dtype=" + this.dtype + ", copy=" + this.copy + ")";
        }
    }

    /**
     * Wraps the memory at {@code address} as an INDArray.
     *
     * @param address native address of the first element
     * @param shape   array shape
     * @param strides per-dimension strides in bytes
     * @param dtype   element data type
     * @param copy    if {@code true}, the wrapped data is duplicated into ND4J-owned memory and
     *                {@link #getAddress()} / {@link #getStrides()} then describe the copy
     */
    public NumpyArray(long address, long[] shape, long[] strides, DataType dtype, boolean copy) {
        this.address = address;
        this.shape = shape;
        this.strides = strides;
        this.dtype = dtype;
        setND4JArray();
        if (copy) {
            this.nd4jArray = this.nd4jArray.dup();
            Nd4j.getAffinityManager().ensureLocation(this.nd4jArray, AffinityManager.Location.HOST);
            this.address = this.nd4jArray.data().address();
            // The copy may be laid out differently from the source (e.g. a non-contiguous view
            // becomes contiguous), so the source's byte strides no longer describe this memory.
            this.strides = toByteStrides(this.nd4jArray.stride(), this.nd4jArray.data().getElementSize());
        }
    }

    /** Returns a new instance backed by a duplicate of this array's data. */
    public NumpyArray copy() {
        return new NumpyArray(this.nd4jArray.dup());
    }

    /** Wraps {@code address} as a {@link DataType#FLOAT} array without copying. */
    public NumpyArray(long address, long[] shape, long[] strides) {
        this(address, shape, strides, DataType.FLOAT, false);
    }

    /** Wraps {@code address} with the given dtype without copying. */
    public NumpyArray(long address, long[] shape, long[] strides, DataType dtype) {
        this(address, shape, strides, dtype, false);
    }

    /**
     * Looks up or creates the INDArray view over {@code address}.
     *
     * <p>On a cache miss a buffer of {@code prod(shape)} elements is created over the native
     * pointer, and the view uses element strides derived from the byte strides. On a cache hit
     * with a different shape (same element count, dtype and strides) the cached array is reshaped.
     */
    private void setND4JArray() {
        long length = 1;
        for (long dim : this.shape) {
            length *= dim;
        }
        String key = cacheKey(this.address, length, this.dtype, this.strides);
        this.nd4jArray = arrayCache.get(key);
        if (this.nd4jArray == null) {
            DataBuffer buffer = Nd4j.createBuffer(
                    nativeOps.pointerForAddress(this.address).limit(length).capacity(length), length, this.dtype);
            long[] elementStrides = toElementStrides(this.strides, buffer.getElementSize());
            this.nd4jArray = Nd4j.create(buffer, this.shape, elementStrides, 0L,
                    Shape.getOrder(this.shape, elementStrides, 1L), this.dtype);
            arrayCache.put(key, this.nd4jArray);
        } else if (!Arrays.equals(this.nd4jArray.shape(), this.shape)) {
            this.nd4jArray = this.nd4jArray.reshape(this.shape);
        }
        Nd4j.getAffinityManager().ensureLocation(this.nd4jArray, AffinityManager.Location.HOST);
    }

    /**
     * Returns the wrapped INDArray after tagging HOST as its current location.
     *
     * <p>NOTE(review): presumably tagged (rather than synchronized) because Python may have written
     * the host memory directly — confirm.
     */
    public INDArray getNd4jArray() {
        Nd4j.getAffinityManager().tagLocation(this.nd4jArray, AffinityManager.Location.HOST);
        return this.nd4jArray;
    }

    /**
     * Describes an existing INDArray as address/shape/byte strides and registers it in the cache.
     *
     * <p>NOTE(review): the address comes from the buffer's base pointer; if {@code iNDArray} is a
     * view with a non-zero offset this may not point at its first element — confirm.
     *
     * @param iNDArray the array to expose; it is moved to HOST first
     */
    public NumpyArray(INDArray iNDArray) {
        Nd4j.getAffinityManager().ensureLocation(iNDArray, AffinityManager.Location.HOST);
        DataBuffer data = iNDArray.data();
        this.address = data.pointer().address();
        this.shape = iNDArray.shape();
        this.strides = toByteStrides(iNDArray.stride(), data.getElementSize());
        this.dtype = iNDArray.dataType();
        this.nd4jArray = iNDArray;
        arrayCache.put(cacheKey(this.address, iNDArray.length(), this.dtype, this.strides), iNDArray);
    }

    /** Returns a new builder. */
    public static NumpyArrayBuilder builder() {
        return new NumpyArrayBuilder();
    }

    /** Native address of the first element. */
    public long getAddress() {
        return this.address;
    }

    /** Array shape (the internal array, not a copy). */
    public long[] getShape() {
        return this.shape;
    }

    /** Per-dimension strides in bytes (the internal array, not a copy). */
    public long[] getStrides() {
        return this.strides;
    }

    /** Element data type. */
    public DataType getDtype() {
        return this.dtype;
    }

    /** Creates an empty instance; all fields are left unset. */
    public NumpyArray() {
    }

    /** Builds the {@code arrayCache} key shared by both wrapping directions. */
    private static String cacheKey(long address, long length, DataType dtype, long[] byteStrides) {
        return address + "_" + length + "_" + dtype + "_" + ArrayUtils.toString(byteStrides);
    }

    /** Converts ND4J element strides to NumPy-style byte strides. */
    private static long[] toByteStrides(long[] elementStrides, int elementSize) {
        long[] byteStrides = new long[elementStrides.length];
        for (int i = 0; i < elementStrides.length; i++) {
            byteStrides[i] = elementStrides[i] * elementSize;
        }
        return byteStrides;
    }

    /** Converts NumPy-style byte strides to ND4J element strides (integer division). */
    private static long[] toElementStrides(long[] byteStrides, int elementSize) {
        long[] elementStrides = new long[byteStrides.length];
        for (int i = 0; i < byteStrides.length; i++) {
            elementStrides[i] = byteStrides[i] / elementSize;
        }
        return elementStrides;
    }

    static {
        // Appears to force ND4J backend initialization before native ops are requested — confirm.
        Nd4j.scalar(1.0d);
        nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
        arrayCache = new HashMap<>();
    }
}
