package org.nd4j.linalg.indexing;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/indexing/Indices.class */
public class Indices {

    /**
     * Collects the offset of every index.
     *
     * <p>The first and last indices always report their own {@code offset()}. An interior index
     * (neither first nor last) whose offset is zero is reported as 1 instead.
     *
     * @param indexes the indices to inspect
     * @return one offset per index, in the same order
     */
    public static int[] offsets(NDArrayIndex... indexes) {
        int[] result = new int[indexes.length];
        for (int i = 0; i < indexes.length; i++) {
            int offset = indexes[i].offset();
            boolean interior = i > 0 && i < indexes.length - 1;
            // NOTE(review): why interior zero offsets are promoted to 1 is not visible in this file.
            result[i] = (interior && offset == 0) ? 1 : offset;
        }
        return result;
    }

    /**
     * Pads {@code indexes} with full-range intervals until there is one index per dimension.
     *
     * @param shape the target shape; its length is the number of indices wanted
     * @param indexes the leading indices supplied by the caller
     * @return {@code indexes} itself when it already has {@code shape.length} entries, otherwise a
     *     new array whose trailing entry for dimension {@code d} is
     *     {@code NDArrayIndex.interval(0, shape[d])}
     * @throws IndexOutOfBoundsException if more indices than dimensions are supplied
     */
    public static NDArrayIndex[] fillIn(int[] shape, NDArrayIndex... indexes) {
        if (shape.length == indexes.length) {
            return indexes;
        }
        NDArrayIndex[] filled = new NDArrayIndex[shape.length];
        // Deliberately not Arrays.copyOf: an over-long input keeps failing here instead of being
        // silently truncated.
        System.arraycopy(indexes, 0, filled, 0, indexes.length);
        for (int d = indexes.length; d < shape.length; d++) {
            filled[d] = NDArrayIndex.interval(0, shape[d]);
        }
        return filled;
    }

    /**
     * Makes the number of indices match the rank of {@code shape}.
     *
     * <p>Too few indices are padded via {@link #fillIn(int[], NDArrayIndex...)}. Too many are
     * truncated to the first {@code shape.length}. When the count already matches, the input array
     * is returned unchanged. The individual indices are never clamped against {@code shape}.
     *
     * @param shape the shape the indices will be applied to
     * @param indexes the caller-supplied indices
     * @return an array of exactly {@code shape.length} indices
     */
    public static NDArrayIndex[] adjustIndices(int[] shape, NDArrayIndex... indexes) {
        if (indexes.length < shape.length) {
            return fillIn(shape, indexes);
        }
        if (indexes.length > shape.length) {
            return Arrays.copyOf(indexes, shape.length);
        }
        return indexes;
    }

    /**
     * Computes strides for the shape spanned by {@code indexes}.
     *
     * @param ordering ordering character forwarded to {@code Nd4j.getStrides} (presumably 'c' or
     *     'f' -- confirm)
     * @param indexes the indices whose resulting shape is used
     * @return the strides reported by {@code Nd4j.getStrides}
     */
    public static int[] strides(char ordering, NDArrayIndex... indexes) {
        return Nd4j.getStrides(shape(indexes), ordering);
    }

    /**
     * Computes the shape spanned by {@code indexes}.
     *
     * <p>Each dimension is {@code |last + 1 - first|} of that index's {@code indices()}. Dimensions
     * that come out as zero or negative are dropped from the result. An index with no positions
     * contributes no dimension, as in the {@code shape(int[], int[], NDArrayIndex...)} overload.
     *
     * @param indexes the indices to measure
     * @return the strictly positive dimension sizes, in index order
     */
    public static int[] shape(NDArrayIndex... indexes) {
        int[] dims = new int[indexes.length];
        for (int i = 0; i < dims.length; i++) {
            int[] positions = indexes[i].indices();
            if (positions.length == 0) {
                continue; // leaves 0, which positiveEntries drops
            }
            dims[i] = Math.abs((positions[positions.length - 1] + 1) - positions[0]);
        }
        return positiveEntries(dims);
    }

    /**
     * Returns whether every index has consecutive positions, that is, no gap larger than 1.
     *
     * @param indexes the indices to check
     * @return {@code true} if all indices are contiguous (vacuously true for no indices)
     */
    public static boolean isContiguous(NDArrayIndex... indexes) {
        return isContiguous(1, indexes);
    }

    /**
     * Returns whether every index has no gap between neighbouring positions larger than
     * {@code maxStep}.
     *
     * @param maxStep the largest allowed absolute difference between neighbouring positions
     * @param indexes the indices to check
     * @return {@code true} if all indices satisfy the bound (vacuously true for no indices)
     */
    public static boolean isContiguous(int maxStep, NDArrayIndex... indexes) {
        for (NDArrayIndex index : indexes) {
            if (!isContiguous(index.indices(), maxStep)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether neighbouring entries of {@code indices} differ by at most 1.
     *
     * @param indices the positions to check
     * @return {@code true} if contiguous (vacuously true for fewer than two entries)
     */
    public static boolean isContiguous(int[] indices) {
        return isContiguous(indices, 1);
    }

    /**
     * Returns whether neighbouring entries of {@code indices} differ by at most {@code maxStep}.
     *
     * @param indices the positions to check
     * @param maxStep the largest allowed absolute difference between neighbouring entries
     * @return {@code true} if no neighbouring pair exceeds {@code maxStep}
     */
    public static boolean isContiguous(int[] indices, int maxStep) {
        for (int k = 1; k < indices.length; k++) {
            if (Math.abs(indices[k] - indices[k - 1]) > maxStep) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds one {@code NDArrayIndex.interval(start[i], end[i])} per position.
     *
     * @param start the interval starts, read with {@code getInt}
     * @param end the interval ends, read with {@code getInt}
     * @return one interval per element of {@code start}
     * @throws IllegalArgumentException if {@code start} and {@code end} differ in length
     */
    public static NDArrayIndex[] createFromStartAndEnd(INDArray start, INDArray end) {
        if (start.length() != end.length()) {
            throw new IllegalArgumentException("Start length must be equal to end length");
        }
        NDArrayIndex[] intervals = new NDArrayIndex[start.length()];
        for (int i = 0; i < intervals.length; i++) {
            intervals[i] = NDArrayIndex.interval(start.getInt(i), end.getInt(i));
        }
        return intervals;
    }

    /**
     * Builds one {@code NDArrayIndex.interval(start[i], end[i], inclusive)} per position.
     *
     * @param start the interval starts, read with {@code getInt}
     * @param end the interval ends, read with {@code getInt}
     * @param inclusive forwarded unchanged to {@code NDArrayIndex.interval} (presumably whether
     *     the end is included -- confirm)
     * @return one interval per element of {@code start}
     * @throws IllegalArgumentException if {@code start} and {@code end} differ in length
     */
    public static NDArrayIndex[] createFromStartAndEnd(INDArray start, INDArray end, boolean inclusive) {
        if (start.length() != end.length()) {
            throw new IllegalArgumentException("Start length must be equal to end length");
        }
        NDArrayIndex[] intervals = new NDArrayIndex[start.length()];
        for (int i = 0; i < intervals.length; i++) {
            intervals[i] = NDArrayIndex.interval(start.getInt(i), end.getInt(i), inclusive);
        }
        return intervals;
    }

    /**
     * Computes the shape that results from applying {@code indexes} to an array of shape
     * {@code shape}, reduced per dimension by {@code offsets}.
     *
     * <p>An {@code NDArrayIndexAll} keeps the full dimension. An interval index spans from its
     * first position to its last; the last position is clamped to {@code shape[i] - 1}. Any other
     * index counts its positions. The matching offset is then subtracted, and non-positive
     * dimensions are dropped from the result.
     *
     * @param shape the shape being indexed; returned as-is (same array) if there are more indices
     *     than dimensions
     * @param offsets per-dimension amounts to subtract; missing trailing entries are treated as 0
     * @param indexes the indices to apply
     * @return the strictly positive resulting dimension sizes, in index order
     */
    public static int[] shape(int[] shape, int[] offsets, NDArrayIndex... indexes) {
        if (indexes.length > shape.length) {
            return shape;
        }
        if (offsets.length < shape.length) {
            offsets = Arrays.copyOf(offsets, shape.length); // zero-pads the tail
        }
        int[] dims = new int[indexes.length];
        for (int i = 0; i < dims.length; i++) {
            NDArrayIndex index = indexes[i];
            if (index instanceof NDArrayIndex.NDArrayIndexAll) {
                dims[i] = shape[i];
                continue;
            }
            int[] positions = index.indices();
            if (positions.length == 0) {
                continue; // leaves 0, which positiveEntries drops
            }
            int last = positions[positions.length - 1];
            // The last valid position is shape[i] - 1, so a position equal to shape[i] is already
            // out of range and must be clamped too.
            if (last >= shape[i]) {
                last = shape[i] - 1;
            }
            dims[i] = index.isInterval() ? Math.abs(last - positions[0]) + 1 : positions.length;
            dims[i] -= offsets[i];
        }
        return positiveEntries(dims);
    }

    /**
     * Computes the indexed shape with no offsets.
     *
     * @param shape the shape being indexed
     * @param indexes the indices to apply
     * @return the same as {@code shape(shape, new int[shape.length], indexes)}
     */
    public static int[] shape(int[] shape, NDArrayIndex... indexes) {
        return shape(shape, new int[shape.length], indexes);
    }

    /** Returns the strictly positive entries of {@code dims}, in order (drops empty dimensions). */
    private static int[] positiveEntries(int[] dims) {
        List<Integer> kept = new ArrayList<Integer>(dims.length);
        for (int d : dims) {
            if (d > 0) {
                kept.add(d);
            }
        }
        return ArrayUtil.toArray(kept);
    }
}
