package org.nd4j.linalg.indexing;

import com.google.common.primitives.Ints;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.LinearIndex;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/indexing/Indices.class */
/**
 * Static helpers that translate {@link INDArrayIndex} specifications into shapes, offsets and
 * related index arithmetic for {@link INDArray}s.
 */
public class Indices {

    /**
     * Returns the number of the vector along the last dimension that linear index {@code i} falls in.
     *
     * @param i   linear index into {@code arr}
     * @param arr the array being indexed
     * @return {@code i / arr.size(-1)}, capped at {@code arr.vectorsAlongDimension(-1) - 1}
     */
    public static int rowNumber(int i, INDArray arr) {
        // Plain integer division; wrapping it in Math.floor would be a no-op.
        int row = i / arr.size(-1);
        int numRows = arr.vectorsAlongDimension(-1);
        return Math.min(row, numRows - 1);
    }

    /**
     * Maps a linear index to an offset, using a vector along the last dimension of {@code arr}.
     *
     * <p>For {@code 'c'} ordering the result is the offset of vector number
     * {@code index % arr.size(-1)} plus {@code index}. For any other ordering the vector number is
     * {@code index * arr.stride(-2) / arr.length()}, and the result is that vector's
     * {@code linearIndex(index % arr.size(-1))}.
     *
     * @param index linear index
     * @param arr   the array being indexed
     * @return the computed offset
     */
    public static int linearOffset(int index, INDArray arr) {
        if (arr.ordering() == 'c') {
            // NOTE(review): the vector number here is index % size(-1) (a column position), not a
            // row number — confirm this is intended.
            int cVectorIndex = index % arr.size(-1);
            return arr.vectorAlongDimension(cVectorIndex, -1).offset() + index;
        }
        int majorStride = arr.stride(-2);
        // Multiply in long so large index * stride products do not overflow int.
        int fVectorIndex = (int) (((long) index * majorStride) / arr.length());
        INDArray vector = arr.vectorAlongDimension(fVectorIndex, -1);
        return vector.linearIndex(index % arr.size(-1));
    }

    /**
     * Runs a {@link LinearIndex} op over every row of {@code arr} (paired with a duplicate of it)
     * and returns the indices the op collected.
     *
     * @param arr the array to index
     * @return the indices reported by {@link LinearIndex#getIndices()}
     */
    public static int[] linearIndices(INDArray arr) {
        LinearIndex op = new LinearIndex(arr, arr.dup(), true);
        Nd4j.getExecutioner().iterateOverAllRows(op);
        return op.getIndices();
    }

    /**
     * Computes per-dimension offsets for the given indices.
     *
     * <ul>
     *   <li>If there is one index per dimension, each entry is that index's offset
     *       ({@code 0} for {@link NDArrayIndexEmpty}).</li>
     *   <li>Otherwise, if any point indices are present, the strictly positive offsets are packed,
     *       in order, into the leading entries.</li>
     *   <li>Otherwise, offsets are assigned sequentially, with {@code 0} for empty indices.</li>
     * </ul>
     *
     * A result of length 1 is padded to {@code {offset, 0}}.
     *
     * @param shape   the shape being indexed
     * @param indices the indices applied to it
     * @return the offsets, at least two entries long when {@code shape.length == 1}
     * @throws IllegalStateException if there are more positive offsets than dimensions
     */
    public static int[] offsets(int[] shape, INDArrayIndex... indices) {
        int[] ret = new int[shape.length];
        if (indices.length == shape.length) {
            for (int i = 0; i < indices.length; i++) {
                ret[i] = indices[i] instanceof NDArrayIndexEmpty ? 0 : indices[i].offset();
            }
        } else if (NDArrayIndex.numPoints(indices) > 0) {
            List<Integer> positiveOffsets = new ArrayList<>();
            for (INDArrayIndex index : indices) {
                int offset = index.offset();
                if (offset > 0) {
                    positiveOffsets.add(offset);
                }
            }
            if (positiveOffsets.size() > shape.length) {
                throw new IllegalStateException("Non zeros greater than shape unable to continue");
            }
            for (int i = 0; i < positiveOffsets.size(); i++) {
                ret[i] = positiveOffsets.get(i);
            }
        } else {
            // NOTE(review): when indices.length > shape.length this writes past the end of ret and
            // throws ArrayIndexOutOfBoundsException. It also reads indices[next] rather than
            // indices[i], so an empty index shifts the offsets of later positions. Confirm intent.
            int next = 0;
            for (int i = 0; i < indices.length; i++) {
                ret[i] = indices[i] instanceof NDArrayIndexEmpty ? 0 : indices[next++].offset();
            }
        }
        if (ret.length == 1) {
            ret = new int[] {ret[0], 0};
        }
        return ret;
    }

    /**
     * Pads {@code indexes} up to {@code shape.length} entries.
     *
     * Each missing trailing dimension {@code i} gets {@code NDArrayIndex.interval(0, shape[i])}.
     *
     * @param shape   the target shape
     * @param indexes the indices supplied so far; expected to be no longer than {@code shape}
     * @return {@code indexes} itself if it already has one entry per dimension, else a new padded array
     * @throws IndexOutOfBoundsException (from {@link System#arraycopy}) if {@code indexes} is longer
     *     than {@code shape}
     */
    public static INDArrayIndex[] fillIn(int[] shape, INDArrayIndex... indexes) {
        if (shape.length == indexes.length) {
            return indexes;
        }
        INDArrayIndex[] filled = new INDArrayIndex[shape.length];
        System.arraycopy(indexes, 0, filled, 0, indexes.length);
        for (int i = indexes.length; i < shape.length; i++) {
            filled[i] = NDArrayIndex.interval(0, shape[i]);
        }
        return filled;
    }

    /**
     * Makes the number of indices match the rank of {@code originalShape}.
     *
     * A single index on a vector shape is returned unchanged. Too few indices are padded via
     * {@link #fillIn(int[], INDArrayIndex...)}, and too many are truncated to the first
     * {@code originalShape.length} entries.
     *
     * @param originalShape the shape being indexed
     * @param indexes       the requested indices
     * @return indices whose count matches the shape (except for the single-index vector case)
     */
    public static INDArrayIndex[] adjustIndices(int[] originalShape, INDArrayIndex... indexes) {
        if (Shape.isVector(originalShape) && indexes.length == 1) {
            return indexes;
        }
        if (indexes.length < originalShape.length) {
            return fillIn(originalShape, indexes);
        }
        if (indexes.length > originalShape.length) {
            // Explicit INDArrayIndex[] so the result can hold any index type, whatever the
            // runtime component type of the caller's array.
            INDArrayIndex[] truncated = new INDArrayIndex[originalShape.length];
            System.arraycopy(indexes, 0, truncated, 0, originalShape.length);
            return truncated;
        }
        return indexes;
    }

    /**
     * Returns the strides, in ordering {@code c}, of the shape produced by
     * {@link #shape(INDArrayIndex...)} for these indices.
     *
     * @param c               the ordering character passed to {@link Nd4j#getStrides}
     * @param nDArrayIndexArr the indices defining the shape
     * @return the strides for that shape
     */
    public static int[] strides(char c, NDArrayIndex... nDArrayIndexArr) {
        return Nd4j.getStrides(shape(nDArrayIndexArr), c);
    }

    /**
     * Returns the lengths of the given indices, with all non-positive lengths dropped.
     *
     * @param indices the indices
     * @return the strictly positive {@link INDArrayIndex#length()} values, in order
     */
    public static int[] shape(INDArrayIndex... indices) {
        List<Integer> lengths = new ArrayList<>(indices.length);
        for (INDArrayIndex index : indices) {
            int length = index.length();
            if (length > 0) {
                lengths.add(length);
            }
        }
        return ArrayUtil.toArray(lengths);
    }

    /**
     * Returns whether every pair of adjacent values in {@code indices} differs by at most
     * {@code diff} in absolute value. An empty or single-element array is contiguous.
     *
     * @param indices the values to check
     * @param diff    the maximum allowed absolute difference between neighbours
     * @return {@code true} if no adjacent difference exceeds {@code diff}
     */
    public static boolean isContiguous(int[] indices, int diff) {
        for (int i = 1; i < indices.length; i++) {
            // long arithmetic: the int difference can overflow and wrap to a small value.
            if (Math.abs((long) indices[i] - indices[i - 1]) > diff) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds one {@code NDArrayIndex.interval(start[i], end[i])} per element.
     *
     * @param start interval starts, read with {@code getInt}
     * @param end   interval ends, read with {@code getInt}
     * @return one interval index per element
     * @throws IllegalArgumentException if {@code start} and {@code end} differ in length
     */
    public static INDArrayIndex[] createFromStartAndEnd(INDArray start, INDArray end) {
        if (start.length() != end.length()) {
            throw new IllegalArgumentException("Start length must be equal to end length");
        }
        INDArrayIndex[] result = new INDArrayIndex[start.length()];
        for (int i = 0; i < result.length; i++) {
            result[i] = NDArrayIndex.interval(start.getInt(i), end.getInt(i));
        }
        return result;
    }

    /**
     * Builds one {@code NDArrayIndex.interval(start[i], end[i], inclusive)} per element.
     *
     * @param start     interval starts, read with {@code getInt}
     * @param end       interval ends, read with {@code getInt}
     * @param inclusive passed through to {@code NDArrayIndex.interval}
     * @return one interval index per element
     * @throws IllegalArgumentException if {@code start} and {@code end} differ in length
     */
    public static INDArrayIndex[] createFromStartAndEnd(INDArray start, INDArray end, boolean inclusive) {
        if (start.length() != end.length()) {
            throw new IllegalArgumentException("Start length must be equal to end length");
        }
        INDArrayIndex[] result = new INDArrayIndex[start.length()];
        for (int i = 0; i < result.length; i++) {
            result[i] = NDArrayIndex.interval(start.getInt(i), end.getInt(i), inclusive);
        }
        return result;
    }

    /**
     * Computes the shape that results from applying {@code indices} to an array of shape
     * {@code shape}.
     *
     * <ul>
     *   <li>A {@link PointIndex} consumes a dimension and contributes nothing.</li>
     *   <li>A non-"all" {@link IntervalIndex} or a {@link SpecifiedIndex} contributes its own
     *       length.</li>
     *   <li>Any other non-{@link NewAxis} index contributes the original dimension.</li>
     *   <li>Unconsumed trailing dimensions are appended.</li>
     *   <li>The result is padded with 1s to at least two entries.</li>
     *   <li>A single point index on a 2-d shape reverses the result.</li>
     *   <li>A {@link NewAxis} inserts a 1: at the front if no {@link NDArrayIndexAll} has been seen
     *       yet, otherwise near its position.</li>
     * </ul>
     *
     * @param shape   the original shape
     * @param indices the indices to apply
     * @return the resulting shape
     */
    public static int[] shape(int[] shape, INDArrayIndex... indices) {
        int newAxesPrepend = 0;
        boolean encounteredAll = false;
        List<Integer> accumShape = new ArrayList<>();
        // Next dimension of `shape` to consume.
        int shapeIndex = 0;
        // Positions in `indices` of new axes seen after an NDArrayIndexAll.
        List<Integer> prependNewAxes = new ArrayList<>();

        for (int i = 0; i < indices.length; i++) {
            INDArrayIndex idx = indices[i];
            if (idx instanceof NDArrayIndexAll) {
                encounteredAll = true;
            }
            if (idx instanceof PointIndex) {
                shapeIndex++;
            } else if (idx instanceof NewAxis) {
                if (encounteredAll) {
                    prependNewAxes.add(i);
                } else {
                    newAxesPrepend++;
                }
            } else if ((idx instanceof IntervalIndex && !(idx instanceof NDArrayIndexAll))
                    || idx instanceof SpecifiedIndex) {
                accumShape.add(idx.length());
                shapeIndex++;
            } else {
                accumShape.add(shape[shapeIndex]);
                shapeIndex++;
            }
        }

        while (shapeIndex < shape.length) {
            accumShape.add(shape[shapeIndex++]);
        }
        while (accumShape.size() < 2) {
            accumShape.add(1);
        }
        if (indices.length == 1 && indices[0] instanceof PointIndex && shape.length == 2) {
            Collections.reverse(accumShape);
        }
        for (int i = 0; i < newAxesPrepend; i++) {
            accumShape.add(0, 1);
        }
        // NOTE(review): each recorded position is shifted down by the number of axes already
        // inserted by this loop; confirm this is the intended placement.
        for (int i = 0; i < prependNewAxes.size(); i++) {
            accumShape.add(prependNewAxes.get(i) - i, 1);
        }
        return Ints.toArray(accumShape);
    }

    /**
     * Intended to compute strides for an indexed view.
     *
     * NOTE(review): not implemented — none of the arguments are used and the result is always an
     * empty array.
     *
     * @param arr     unused
     * @param indexes unused
     * @param shape   unused
     * @return an empty array
     */
    public static int[] stride(INDArray arr, INDArrayIndex[] indexes, int... shape) {
        return new int[0];
    }

    /**
     * Returns whether every index has length 1.
     *
     * NOTE(review): {@code indexOver} is not consulted, so neither its rank nor the number of new
     * axes affects the answer. Confirm whether a rank check was intended.
     *
     * @param indexOver the array being indexed (currently unused)
     * @param indexes   the indices
     * @return {@code true} if every index has length 1 (also for no indices)
     */
    public static boolean isScalar(INDArray indexOver, INDArrayIndex... indexes) {
        for (INDArrayIndex index : indexes) {
            if (index.length() != 1) {
                return false;
            }
        }
        return true;
    }
}
