package org.nd4j.linalg.indexing;

import com.google.common.primitives.Longs;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/indexing/ShapeOffsetResolution.class */
/**
 * Computes the shape, strides, per-dimension offsets and linear element offset of the view of an
 * {@link INDArray} selected by a list of {@link INDArrayIndex} values.
 *
 * <p>Call {@link #exec(INDArrayIndex...)} (or {@link #tryShortCircuit(INDArrayIndex...)} with
 * already-resolved indexes) and read the result through {@link #getShapes()},
 * {@link #getStrides()}, {@link #getOffsets()} and {@link #getOffset()}. Instances are mutable:
 * the fields hold the result of the most recent resolution.
 */
public class ShapeOffsetResolution implements Serializable {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) ShapeOffsetResolution.class);

    /** The array being indexed. */
    private INDArray arr;
    /**
     * Filled by {@link #resolveFixedDimensionsCOO(INDArrayIndex...)} for sparse arrays: 1 where an
     * index fixes a dimension, 0 where it keeps a range.
     */
    private int[] fixed;
    private int[] prependAxis;
    /** Per-dimension offsets of the resolved view. */
    private long[] offsets;
    /** Shape of the resolved view. */
    private long[] shapes;
    /** Strides of the resolved view. */
    private long[] strides;
    /** Linear element offset of the resolved view; initialised to -1. */
    private long offset = -1;

    public ShapeOffsetResolution(INDArray iNDArray) {
        this.arr = iNDArray;
    }

    /**
     * Attempts to resolve {@code indexes} through one of several fast paths that avoid the general
     * algorithm in {@link #exec(INDArrayIndex...)}.
     *
     * <p>On success the shape, strides, per-dimension offsets and linear offset fields are written
     * and {@code true} is returned. When no fast path applies {@code false} is returned; note that
     * the specified-index path may already have written some fields before giving up.
     *
     * @param iNDArrayIndexArr the (already resolved) indexes to apply to {@link #getArr()}
     * @return {@code true} if a fast path handled the indexes
     */
    public boolean tryShortCircuit(INDArrayIndex... iNDArrayIndexArr) {
        INDArrayIndex[] indexes = iNDArrayIndexArr;
        // No index at all: no fast path applies. (Previously vectors failed on indexes[0] here.)
        if (indexes.length == 0) {
            return false;
        }

        int numPoints = 0;
        int numIntervals = 0; // IntervalIndex instances that are not NDArrayIndexAll
        int numNewAxes = 0;
        int numAll = 0;
        int numSpecified = 0;
        for (INDArrayIndex index : indexes) {
            if (index instanceof PointIndex) {
                numPoints++;
            }
            if (index instanceof SpecifiedIndex) {
                numSpecified++;
            } else if (index instanceof IntervalIndex && !(index instanceof NDArrayIndexAll)) {
                numIntervals++;
            } else if (index instanceof NewAxis) {
                numNewAxes++;
            } else if (index instanceof NDArrayIndexAll) {
                numAll++;
            }
        }

        if (this.arr.isVector()) {
            // The whole vector.
            if (indexes[0] instanceof NDArrayIndexAll && indexes.length == 1) {
                this.offset = 0L;
                this.shapes = this.arr.shape();
                this.strides = this.arr.stride();
                this.offsets = new long[this.arr.rank()];
                return true;
            }
            // A single element, addressed either as (point, all) or by a lone point. The length
            // check matters: without it a lone point index read indexes[1] and threw.
            if (indexes[0] instanceof PointIndex
                    && (indexes.length == 1 || indexes[1] instanceof NDArrayIndexAll)) {
                resolveVectorElement(indexes[0]);
                return true;
            }
            if (this.arr.isRowVector()) {
                if (indexes[0] instanceof PointIndex) {
                    // (point, interval) on a row vector: a 1 x n slice.
                    if (indexes.length > 1 && indexes[1] instanceof IntervalIndex) {
                        this.offset = indexes[1].offset();
                        this.shapes = new long[] {1, indexes[1].length()};
                        this.strides = new long[] {0, indexes[1].stride()};
                        this.offsets = new long[2];
                        return true;
                    }
                    // Otherwise fall through to the general fast paths below.
                } else if (!(indexes[0] instanceof IntervalIndex)) {
                    return false;
                }
            } else if (indexes.length > 1 && indexes[1] instanceof PointIndex) {
                if (indexes[0] instanceof IntervalIndex) {
                    // NOTE(review): shape and stride are taken from indexes[1] (the point) rather
                    // than indexes[0] (the interval) -- kept unchanged, confirm intent.
                    this.offset = indexes[0].offset();
                    this.shapes = new long[] {indexes[1].length(), 1};
                    this.strides = new long[] {indexes[1].stride(), 0};
                    this.offsets = new long[2];
                    return true;
                }
                // Otherwise fall through to the general fast paths below.
            } else if (!(indexes[0] instanceof IntervalIndex)) {
                return false;
            }
        }

        if (numSpecified > 0 && numIntervals < 1 && numNewAxes < 1 && numAll > 0 && numPoints < 1
                && this.arr.rank() == 2) {
            return resolveSpecifiedMatrix(indexes);
        }
        if (numSpecified < 1 && numIntervals < 1 && numNewAxes < 1 && numPoints > 0 && numAll > 0) {
            resolvePointsAndAll(indexes, numPoints);
            return true;
        }
        if (numSpecified < 1 && numIntervals > 0 && numNewAxes < 1 && numPoints < 1 && numAll > 0) {
            resolveIntervalsAndAll(indexes);
            return true;
        }
        if (numSpecified < 1 && numIntervals < 1 && numNewAxes > 0 && numPoints < 1 && numAll > 0) {
            resolveNewAxesAndAll(indexes, numNewAxes);
            return true;
        }
        return false;
    }

    /** Resolves a 1 x 1 view onto the element of a vector selected by {@code point}. */
    private void resolveVectorElement(INDArrayIndex point) {
        this.shapes = new long[2];
        this.strides = new long[2];
        for (int dim = 0; dim < 2; dim++) {
            this.shapes[dim] = 1;
            this.strides[dim] = this.arr.stride(dim);
        }
        this.offsets = new long[this.arr.rank()];
        // Row vectors use the stride of dimension 1, other vectors that of dimension 0.
        long stepStride = this.arr.isRowVector() ? this.strides[1] : this.strides[0];
        this.offset = point.offset() * stepStride;
    }

    /**
     * Fast path for a rank-2 array indexed only by {@link SpecifiedIndex} and
     * {@link NDArrayIndexAll}. Returns {@code false} (with the result fields partially written)
     * when a specified index lists {@code rank} or more positions.
     */
    private boolean resolveSpecifiedMatrix(INDArrayIndex[] indexes) {
        int rank = this.arr.rank();
        this.shapes = new long[rank];
        this.strides = new long[rank];
        this.offsets = new long[rank];
        this.offset = 0L;
        boolean allSpecified = true;
        for (int dim = 0; dim < 2; dim++) {
            allSpecified = allSpecified && indexes[dim] instanceof SpecifiedIndex;
        }
        for (int dim = 0; dim < rank; dim++) {
            INDArrayIndex index = indexes[dim];
            if (index instanceof SpecifiedIndex) {
                SpecifiedIndex specified = (SpecifiedIndex) index;
                if (specified.getIndexes().length >= rank) {
                    return false;
                }
                this.shapes[dim] = index.length();
                this.offsets[dim] = index.offset();
                // The offset is overwritten, not accumulated: when both leading indexes are
                // specified only dimension 0 sets it, otherwise the last specified one wins.
                if (!allSpecified || dim == 0) {
                    this.offset = this.offsets[dim] * this.arr.stride(dim);
                }
                if (index.length() != 1) {
                    // NOTE(review): selects a position by the dimension number -- kept unchanged.
                    this.strides[dim] = this.arr.stride(dim) * specified.getIndexes()[dim];
                } else {
                    this.strides[dim] = 1;
                }
            } else {
                if (!(index instanceof NDArrayIndexAll)) {
                    throw new IllegalArgumentException("Illegal opType of index " + index.getClass().getName());
                }
                this.shapes[dim] = this.arr.size(dim);
                this.strides[dim] = this.arr.tensorAlongDimension(0, dim).elementWiseStride();
            }
        }
        return true;
    }

    /** Fast path for indexes made only of {@link PointIndex} and {@link NDArrayIndexAll}. */
    private void resolvePointsAndAll(INDArrayIndex[] indexes, int numPoints) {
        int outRank = Math.max(this.arr.rank() - numPoints, 2);
        long[] outShape = new long[outRank];
        Arrays.fill(outShape, 1L);
        long[] outStrides = new long[outRank];
        Arrays.fill(outStrides, this.arr.elementStride());
        long[] outOffsets = new long[outRank];
        long linearOffset = 0;
        int outDim = 0;
        for (int dim = 0; dim < indexes.length; dim++) {
            if (indexes[dim] instanceof NDArrayIndexAll) {
                outShape[outDim] = this.arr.size(dim);
                outStrides[outDim] = this.arr.stride(dim);
                outDim++;
            } else {
                // A point index removes the dimension and shifts the start of the view.
                linearOffset += indexes[dim].offset() * this.arr.stride(dim);
            }
        }
        // For a matrix whose first index is a point the result is reversed, e.g. [n, 1] -> [1, n].
        if (this.arr.isMatrix() && indexes[0] instanceof PointIndex) {
            outShape = ArrayUtil.reverseCopy(outShape);
            outStrides = ArrayUtil.reverseCopy(outStrides);
        }
        this.strides = outStrides;
        this.shapes = outShape;
        this.offsets = outOffsets;
        this.offset = linearOffset;
    }

    /** Fast path for indexes made only of {@link IntervalIndex} and {@link NDArrayIndexAll}. */
    private void resolveIntervalsAndAll(INDArrayIndex[] indexes) {
        int outRank = Math.max(this.arr.rank(), 2);
        long[] outShape = new long[outRank];
        Arrays.fill(outShape, 1L);
        long[] outStrides = new long[outRank];
        Arrays.fill(outStrides, this.arr.elementStride());
        long[] outOffsets = new long[outRank];
        for (int dim = 0; dim < outRank; dim++) {
            INDArrayIndex index = indexes[dim];
            if (index instanceof NDArrayIndexAll) {
                outShape[dim] = this.arr.size(dim);
                outStrides[dim] = this.arr.stride(dim);
                outOffsets[dim] = index.offset();
            } else if (index instanceof IntervalIndex) {
                outShape[dim] = index.length();
                outStrides[dim] = index.stride() * this.arr.stride(dim);
                outOffsets[dim] = index.offset();
            }
        }
        this.shapes = outShape;
        this.strides = outStrides;
        this.offsets = outOffsets;
        this.offset = 0L;
        for (int dim = 0; dim < indexes.length; dim++) {
            // outStrides already includes the index step; divide it back out.
            this.offset += outOffsets[dim] * (outStrides[dim] / indexes[dim].stride());
        }
    }

    /** Fast path for indexes made only of {@link NewAxis} and {@link NDArrayIndexAll}. */
    private void resolveNewAxesAndAll(INDArrayIndex[] indexes, int numNewAxes) {
        int outRank = Math.max(this.arr.rank(), 2) + numNewAxes;
        long[] outShape = new long[outRank];
        Arrays.fill(outShape, 1L);
        long[] outStrides = new long[outRank];
        Arrays.fill(outStrides, this.arr.elementStride());
        long[] outOffsets = new long[outRank];
        int arrDim = 0;
        for (int dim = 0; dim < outRank; dim++) {
            if (dim >= indexes.length) {
                // Dimensions past the last index are taken whole.
                outShape[dim] = this.arr.size(arrDim);
                outStrides[dim] = this.arr.stride(arrDim);
                arrDim++;
            } else if (!(indexes[dim] instanceof NewAxis) && indexes[dim] instanceof NDArrayIndexAll) {
                // NOTE(review): written at position arrDim rather than dim, so e.g.
                // [newAxis, all, all] yields [s0, s1, 1] -- kept unchanged, confirm intent.
                outShape[arrDim] = this.arr.size(arrDim);
                outStrides[arrDim] = this.arr.stride(arrDim);
                arrDim++;
            }
        }
        this.shapes = outShape;
        this.strides = outStrides;
        this.offsets = outOffsets;
        // Reset before accumulating; previously the stale value (initially -1) leaked through.
        this.offset = 0L;
        for (int dim = 0; dim < indexes.length; dim++) {
            this.offset += outOffsets[dim] * (outStrides[dim] / indexes[dim].stride());
        }
    }

    /**
     * Resolves the view of {@link #getArr()} described by the given indexes, writing the shape,
     * strides, per-dimension offsets and linear offset fields.
     *
     * <p>Point indexes are bounds-checked first. For sparse arrays the fixed dimensions are recorded.
     * The indexes are then passed through {@code NDArrayIndex.resolve}; if a fast path in
     * {@link #tryShortCircuit(INDArrayIndex...)} applies it is used. Otherwise the general algorithm
     * below accumulates shape, strides and offsets per index, and finally drops size-1 dimensions
     * that came from a {@link SpecifiedIndex}.
     *
     * @param iNDArrayIndexArr the indexes to apply
     * @throws IllegalArgumentException if a point index is out of bounds
     */
    public void exec(INDArrayIndex... iNDArrayIndexArr) {
        INDArrayIndex[] indexes = iNDArrayIndexArr;
        long[] shape = this.arr.shape();
        if (this.arr.isSparse()) {
            resolveFixedDimensionsCOO(indexes);
        }
        for (int i = 0; i < indexes.length; i++) {
            INDArrayIndex index = indexes[i];
            if (!(index instanceof PointIndex)) {
                continue;
            }
            // A lone point on a vector addresses one of its elements, so bound it by the vector
            // length. The old check used shape[i + 1], which rejected valid indexes on column
            // vectors and would read past the shape of a rank-1 vector.
            long bound = this.arr.isVector() && indexes.length == 1 ? this.arr.length() : shape[i];
            if (index.current() >= bound) {
                throw new IllegalArgumentException("INDArrayIndex[" + i + "] is out of bounds (value: " + index.current() + ")");
            }
        }

        INDArrayIndex[] resolved = NDArrayIndex.resolve(this.arr.shapeInfoDataBuffer(), indexes);
        if (tryShortCircuit(resolved)) {
            return;
        }

        int numIntervals = 0;
        int numPointIndexes = 0;
        int shapeIndex = 0;
        int strideIndex = 0;
        List<Long> accumShape = new ArrayList<>();
        List<Long> accumStrides = new ArrayList<>();
        List<Long> accumOffsets = new ArrayList<>();
        List<Long> intervalStrides = new ArrayList<>();
        List<Long> pointStrides = new ArrayList<>();
        List<Long> pointOffsets = new ArrayList<>();

        for (INDArrayIndex index : resolved) {
            if (index instanceof PointIndex) {
                // Points drop their dimension and contribute to the linear offset.
                pointOffsets.add(Long.valueOf(index.offset()));
                pointStrides.add(Long.valueOf(this.arr.stride(strideIndex)));
                numPointIndexes++;
                shapeIndex++;
                strideIndex++;
            } else if (index instanceof NewAxis) {
                accumShape.add(1L);
                accumOffsets.add(0L);
                accumStrides.add(0L);
            } else if ((!(index instanceof IntervalIndex) || index instanceof NDArrayIndexAll)
                    && !(index instanceof SpecifiedIndex)) {
                // The whole dimension is kept.
                accumShape.add(Long.valueOf(shape[shapeIndex++]));
                accumStrides.add(Long.valueOf(this.arr.stride(strideIndex++)));
                accumOffsets.add(Long.valueOf(index.offset()));
            } else {
                // Interval (not all) or specified index.
                if (index instanceof IntervalIndex) {
                    accumStrides.add(Long.valueOf(this.arr.stride(strideIndex) * index.stride()));
                    intervalStrides.add(Long.valueOf(index.stride()));
                    numIntervals++;
                } else {
                    accumStrides.add(Long.valueOf(this.arr.stride(strideIndex)));
                }
                accumShape.add(Long.valueOf(index.length()));
                accumOffsets.add(Long.valueOf(index.offset()));
                shapeIndex++;
                strideIndex++;
            }
        }

        // Dimensions past the last index are kept (as size 1 when the shape is a vector).
        while (shapeIndex < shape.length) {
            if (Shape.isVector(shape)) {
                accumShape.add(1L);
            } else {
                accumShape.add(Long.valueOf(shape[shapeIndex]));
            }
            shapeIndex++;
        }

        int targetOffsetCount = shape.length <= 2 ? shape.length : shape.length - numPointIndexes;
        boolean padOffsets = accumShape.size() != accumStrides.size() && accumOffsets.size() != accumShape.size();
        while (accumOffsets.size() < targetOffsetCount && padOffsets) {
            accumOffsets.add(0L);
        }

        // The result is at least two-dimensional.
        while (accumShape.size() < 2) {
            if (Shape.isRowVectorShape(this.arr.shape())) {
                accumShape.add(0, 1L);
            } else {
                accumShape.add(1L);
            }
        }
        while (strideIndex < accumShape.size()) {
            accumStrides.add(Long.valueOf(this.arr.stride(strideIndex++)));
        }

        // NOTE(review): the element tested is at pruneIdx but the one removed is always the last,
        // and pruneIdx can go negative (IndexOutOfBoundsException) -- kept unchanged, confirm intent.
        int pruneIdx = accumOffsets.size() - 1;
        while (accumOffsets.size() > accumShape.size()) {
            if (accumOffsets.get(pruneIdx).longValue() == 0L) {
                accumOffsets.remove(accumOffsets.size() - 1);
            }
            pruneIdx--;
        }
        if (accumStrides.size() < accumOffsets.size()) {
            accumStrides.addAll(pointStrides);
        }
        while (accumOffsets.size() < accumShape.size()) {
            if (Shape.isRowVectorShape(this.arr.shape())) {
                accumOffsets.add(0, 0L);
            } else {
                accumOffsets.add(0L);
            }
        }

        if (Shape.isMatrix(shape) && resolved[0] instanceof PointIndex && resolved[1] instanceof NDArrayIndexAll) {
            Collections.reverse(accumShape);
        }
        if (this.arr.isMatrix() && resolved[0] instanceof PointIndex && resolved[1] instanceof IntervalIndex) {
            this.shapes = new long[] {1, ((IntervalIndex) resolved[1]).length()};
        } else {
            this.shapes = Longs.toArray(accumShape);
        }
        boolean isColumnVectorShape = Shape.isColumnVectorShape(this.shapes);
        while (accumStrides.size() < accumOffsets.size()) {
            if (isColumnVectorShape) {
                accumStrides.add(Long.valueOf(this.arr.elementStride()));
            } else {
                accumStrides.add(0, Long.valueOf(this.arr.elementStride()));
            }
        }
        this.strides = Longs.toArray(accumStrides);
        this.offsets = Longs.toArray(accumOffsets);

        // Offset contributed by point indexes.
        if (numPointIndexes <= 0 || pointStrides.isEmpty()) {
            this.offset = 0L;
        } else {
            while (pointOffsets.size() < pointStrides.size()) {
                pointOffsets.add(0L);
            }
            boolean useSecondIndexOffset = this.arr.isRowVector()
                    && !intervalStrides.isEmpty()
                    && pointOffsets.get(0).longValue() == 0L
                    && !(resolved[1] instanceof IntervalIndex);
            this.offset = useSecondIndexOffset
                    ? resolved[1].offset()
                    : ArrayUtil.dotProductLong2(pointOffsets, pointStrides);
        }

        // Offset contributed by the kept (all / interval / specified / new-axis) dimensions.
        if (numIntervals > 0 && this.arr.rank() > 2) {
            this.offset += ArrayUtil.dotProductLong2(accumOffsets, accumStrides);
        } else if (numIntervals > 0 && anyHaveStrideOne(resolved)) {
            this.offset += ArrayUtil.calcOffsetLong2(accumShape, accumOffsets, accumStrides);
        } else {
            this.offset += ArrayUtil.calcOffsetLong2(accumShape, accumOffsets, accumStrides) / Math.max(1, numIntervals);
        }

        // Drop size-1 dimensions that were produced by a SpecifiedIndex.
        List<Integer> squeezedDims = new ArrayList<>();
        for (int dim = 0; dim < Math.min(this.shapes.length, resolved.length); dim++) {
            if (this.shapes[dim] == 1 && resolved[dim] instanceof SpecifiedIndex) {
                squeezedDims.add(dim);
            }
        }
        if (squeezedDims.isEmpty()) {
            return;
        }
        List<Long> keptShape = new ArrayList<>();
        List<Long> keptStrides = new ArrayList<>();
        for (int dim = 0; dim < this.shapes.length; dim++) {
            if (!squeezedDims.contains(dim)) {
                keptShape.add(Long.valueOf(this.shapes[dim]));
                keptStrides.add(Long.valueOf(this.strides[dim]));
            }
        }
        this.shapes = Longs.toArray(keptShape);
        this.strides = Longs.toArray(keptStrides);
    }

    /**
     * Records in {@link #getFixed()}, one entry per array dimension, whether each index fixes that
     * dimension: 1 for a {@link PointIndex} or a single-element {@link SpecifiedIndex}, 0 for an
     * interval / all index or a multi-element specified index. {@link NewAxis} entries are skipped.
     *
     * @param iNDArrayIndexArr the indexes to inspect
     */
    public void resolveFixedDimensionsCOO(INDArrayIndex... iNDArrayIndexArr) {
        this.fixed = new int[this.arr.rank()];
        int dim = 0;
        for (INDArrayIndex index : iNDArrayIndexArr) {
            if (index instanceof PointIndex) {
                this.fixed[dim++] = 1;
            }
            if (index instanceof IntervalIndex || index instanceof NDArrayIndexAll) {
                this.fixed[dim++] = 0;
            }
            if (index instanceof SpecifiedIndex) {
                this.fixed[dim++] = ((SpecifiedIndex) index).getIndexes().length == 1 ? 1 : 0;
            }
            // A NewAxis does not consume a dimension of the array, so it is skipped.
        }
    }

    /** Currently a no-op. */
    public void resolveSparseOffsetsCOO() {
    }

    /** Returns {@code true} if any of the given indexes has a stride of exactly 1. */
    private boolean anyHaveStrideOne(INDArrayIndex... iNDArrayIndexArr) {
        for (INDArrayIndex iNDArrayIndex : iNDArrayIndexArr) {
            if (iNDArrayIndex.stride() == 1) {
                return true;
            }
        }
        return false;
    }

    /** Returns {@code true} if no index has an offset of 0. Not referenced in this class. */
    private boolean allIndexGreatherThanZero(INDArrayIndex... iNDArrayIndexArr) {
        for (INDArrayIndex iNDArrayIndex : iNDArrayIndexArr) {
            if (iNDArrayIndex.offset() == 0) {
                return false;
            }
        }
        return true;
    }

    public INDArray getArr() {
        return this.arr;
    }

    public int[] getFixed() {
        return this.fixed;
    }

    public int[] getPrependAxis() {
        return this.prependAxis;
    }

    public long[] getOffsets() {
        return this.offsets;
    }

    public long[] getShapes() {
        return this.shapes;
    }

    public long[] getStrides() {
        return this.strides;
    }

    public long getOffset() {
        return this.offset;
    }

    public void setArr(INDArray iNDArray) {
        this.arr = iNDArray;
    }

    public void setFixed(int[] iArr) {
        this.fixed = iArr;
    }

    public void setPrependAxis(int[] iArr) {
        this.prependAxis = iArr;
    }

    public void setOffsets(long[] jArr) {
        this.offsets = jArr;
    }

    public void setShapes(long[] jArr) {
        this.shapes = jArr;
    }

    public void setStrides(long[] jArr) {
        this.strides = jArr;
    }

    public void setOffset(long j) {
        this.offset = j;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ShapeOffsetResolution)) {
            return false;
        }
        ShapeOffsetResolution shapeOffsetResolution = (ShapeOffsetResolution) obj;
        if (!shapeOffsetResolution.canEqual(this)) {
            return false;
        }
        INDArray arr = getArr();
        INDArray arr2 = shapeOffsetResolution.getArr();
        if (arr == null) {
            if (arr2 != null) {
                return false;
            }
        } else if (!arr.equals(arr2)) {
            return false;
        }
        return Arrays.equals(getFixed(), shapeOffsetResolution.getFixed()) && Arrays.equals(getPrependAxis(), shapeOffsetResolution.getPrependAxis()) && Arrays.equals(getOffsets(), shapeOffsetResolution.getOffsets()) && Arrays.equals(getShapes(), shapeOffsetResolution.getShapes()) && Arrays.equals(getStrides(), shapeOffsetResolution.getStrides()) && getOffset() == shapeOffsetResolution.getOffset();
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof ShapeOffsetResolution;
    }

    @Override
    public int hashCode() {
        INDArray arr = getArr();
        int hashCode = (((((((((((1 * 59) + (arr == null ? 43 : arr.hashCode())) * 59) + Arrays.hashCode(getFixed())) * 59) + Arrays.hashCode(getPrependAxis())) * 59) + Arrays.hashCode(getOffsets())) * 59) + Arrays.hashCode(getShapes())) * 59) + Arrays.hashCode(getStrides());
        long offset = getOffset();
        return (hashCode * 59) + ((int) ((offset >>> 32) ^ offset));
    }

    @Override
    public String toString() {
        return "ShapeOffsetResolution(arr=" + getArr() + ", fixed=" + Arrays.toString(getFixed()) + ", prependAxis=" + Arrays.toString(getPrependAxis()) + ", offsets=" + Arrays.toString(getOffsets()) + ", shapes=" + Arrays.toString(getShapes()) + ", strides=" + Arrays.toString(getStrides()) + ", offset=" + getOffset() + ")";
    }
}
