package org.nd4j.linalg.api.ops.impl.shape;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Strided slice of a single input array, following TensorFlow's {@code StridedSlice} op
 * (native op name {@code "stridedslice"}).
 *
 * <p>The slice can be configured in two ways:
 * <ul>
 *   <li><b>Statically</b>: {@code begin}/{@code end}/{@code strides} are passed to a constructor
 *       and encoded as integer arguments; the op then has a single input.</li>
 *   <li><b>Via input variables</b>: begin/end/strides are supplied as inputs 1..3 (this is how
 *       TensorFlow import maps them, see {@link #mappingsForFunction()}); only the five masks are
 *       encoded as integer arguments.</li>
 * </ul>
 *
 * <p>Integer argument layout: {@code [beginMask, ellipsisMask, endMask, newAxisMask,
 * shrinkAxisMask, begin..., end..., strides...]} (the trailing arrays only in the static case).
 * Note that this mask order differs from the constructor parameter order
 * ({@code beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask}).
 */
public class StridedSlice extends DynamicCustomOp {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) StridedSlice.class);

    /** Per-dimension start indices; set by the array-based constructors, may be null otherwise. */
    private long[] begin;
    /** Per-dimension end indices; set by the array-based constructors, may be null otherwise. */
    private long[] end;
    /** Per-dimension strides; set by the array-based constructors, may be null otherwise. */
    private long[] strides;
    private int beginMask;
    private int endMask;
    private int ellipsisMask;
    private int newAxisMask;
    private int shrinkAxisMask;

    /**
     * No-arg constructor; the configuration is filled in afterwards
     * (e.g. by {@link #initFromTensorFlow}).
     */
    public StridedSlice() {
    }

    /** SameDiff strided slice with all masks set to 0. */
    public StridedSlice(SameDiff sameDiff, SDVariable in, int[] begin, int[] end, int[] strides) {
        this(sameDiff, in, begin, end, strides, 0, 0, 0, 0, 0);
    }

    /** SameDiff strided slice with all masks set to 0. */
    public StridedSlice(SameDiff sameDiff, SDVariable in, long[] begin, long[] end, long[] strides) {
        this(sameDiff, in, begin, end, strides, 0, 0, 0, 0, 0);
    }

    /**
     * SameDiff strided slice with explicit masks.
     *
     * @param sameDiff       owning SameDiff instance
     * @param in             variable to slice
     * @param begin          per-dimension start indices (copied)
     * @param end            per-dimension end indices (copied); same length as {@code begin}
     * @param strides        per-dimension strides (copied); same length as {@code begin}
     * @param beginMask      TF-style {@code begin_mask}
     * @param endMask        TF-style {@code end_mask}
     * @param ellipsisMask   TF-style {@code ellipsis_mask}
     * @param newAxisMask    TF-style {@code new_axis_mask}
     * @param shrinkAxisMask TF-style {@code shrink_axis_mask}
     * @throws NullPointerException     if {@code begin}, {@code end} or {@code strides} is null
     * @throws IllegalArgumentException if the three arrays differ in length
     */
    public StridedSlice(SameDiff sameDiff, SDVariable in, @NonNull long[] begin, @NonNull long[] end,
            @NonNull long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask,
            int shrinkAxisMask) {
        super((String) null, sameDiff, new SDVariable[]{in});
        requireNonNullArray(begin, "begin");
        requireNonNullArray(end, "end");
        requireNonNullArray(strides, "strides");
        // Defensive copies: the arrays are encoded as iArgs now and reused later in doDiff,
        // so later mutation by the caller must not make the two disagree.
        configure(begin.clone(), end.clone(), strides.clone(),
                beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    }

    /**
     * SameDiff strided slice with explicit masks; {@code int[]} variant of the {@code long[]}
     * constructor (same parameters and exceptions).
     */
    public StridedSlice(SameDiff sameDiff, SDVariable in, @NonNull int[] begin, @NonNull int[] end,
            @NonNull int[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask,
            int shrinkAxisMask) {
        super((String) null, sameDiff, new SDVariable[]{in});
        requireNonNullArray(begin, "begin");
        requireNonNullArray(end, "end");
        requireNonNullArray(strides, "strides");
        configure(ArrayUtil.toLongArray(begin), ArrayUtil.toLongArray(end), ArrayUtil.toLongArray(strides),
                beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    }

    /** Eager-mode strided slice on an array; {@code int[]} variant of the {@code long[]} constructor. */
    public StridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask, int endMask,
            int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
        this(in, ArrayUtil.toLongArray(begin), ArrayUtil.toLongArray(end), ArrayUtil.toLongArray(strides),
                beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    }

    /**
     * Eager-mode strided slice on an array.
     *
     * @throws NullPointerException     if {@code begin}, {@code end} or {@code strides} is null
     * @throws IllegalArgumentException if the three arrays differ in length
     */
    public StridedSlice(INDArray in, long[] begin, long[] end, long[] strides, int beginMask, int endMask,
            int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
        requireNonNullArray(begin, "begin");
        requireNonNullArray(end, "end");
        requireNonNullArray(strides, "strides");
        addInputArgument(in);
        configure(begin.clone(), end.clone(), strides.clone(),
                beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    }

    /** Throws an NPE using the same message format as the SameDiff constructors' null checks. */
    private static void requireNonNullArray(Object array, String name) {
        if (array == null) {
            throw new NullPointerException(name + " is marked non-null but is null");
        }
    }

    /** Stores a static slice configuration (arrays already private copies) and encodes it as iArgs. */
    private void configure(long[] begin, long[] end, long[] strides, int beginMask, int endMask,
            int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
        this.begin = begin;
        this.end = end;
        this.strides = strides;
        this.beginMask = beginMask;
        this.endMask = endMask;
        this.ellipsisMask = ellipsisMask;
        this.newAxisMask = newAxisMask;
        this.shrinkAxisMask = shrinkAxisMask;
        addArguments();
    }

    /**
     * Encodes masks followed by begin, end and strides as integer arguments.
     * The three arrays are flattened back-to-back, so they must have equal lengths
     * or the boundaries between them would be misread.
     */
    private void addArguments() {
        Preconditions.checkArgument(begin.length == end.length && end.length == strides.length,
                "begin, end and strides must have the same length, got lengths %s, %s and %s",
                begin.length, end.length, strides.length);
        addMaskArguments();
        addIArgument(begin);
        addIArgument(end);
        addIArgument(strides);
    }

    /** Appends the five masks in the iArg order documented on the class. */
    private void addMaskArguments() {
        addIArgument(beginMask);
        addIArgument(ellipsisMask);
        addIArgument(endMask);
        addIArgument(newAxisMask);
        addIArgument(shrinkAxisMask);
    }

    @Override
    public String opName() {
        return "stridedslice";
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx opName found for " + opName());
    }

    @Override
    public String tensorflowName() {
        return "StridedSlice";
    }

    /**
     * Checks the input count (1, 3 or 4) and that at least the five mask iArgs are present.
     *
     * @throws ND4JIllegalStateException if either condition fails
     */
    @Override
    public void assertValidForExecution() {
        if (numInputArguments() != 1 && numInputArguments() != 3 && numInputArguments() != 4) {
            throw new ND4JIllegalStateException("Num input arguments must be 1 3 or 4.");
        }
        if (numIArguments() < 5) {
            throw new ND4JIllegalStateException("Number of integer arguments must >= 5");
        }
    }

    /**
     * Reads the five TF mask attributes and encodes them as iArgs. Begin/end/strides are TF
     * inputs 1..3 and are consumed as input variables, so they are not read here.
     *
     * @throws IllegalStateException if the node has fewer than 4 inputs
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
        // Previously enforced implicitly by fetching (and discarding) inputs 1..3.
        Preconditions.checkState(nodeDef.getInputCount() >= 4,
                "TensorFlow StridedSlice node %s must have at least 4 inputs (input, begin, end, strides), got %s",
                nodeDef.getName(), nodeDef.getInputCount());
        this.beginMask = (int) nodeDef.getAttrOrThrow("begin_mask").getI();
        this.ellipsisMask = (int) nodeDef.getAttrOrThrow("ellipsis_mask").getI();
        this.endMask = (int) nodeDef.getAttrOrThrow("end_mask").getI();
        this.newAxisMask = (int) nodeDef.getAttrOrThrow("new_axis_mask").getI();
        this.shrinkAxisMask = (int) nodeDef.getAttrOrThrow("shrink_axis_mask").getI();
        addMaskArguments();
    }

    /** TF property mappings: begin/end/strides from inputs 1..3, masks from node attributes. */
    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        Map<String, PropertyMapping> tfMappings = new HashMap<>();
        tfMappings.put("begin", tfInput(1, "begin"));
        tfMappings.put("end", tfInput(2, "end"));
        tfMappings.put("strides", tfInput(3, "strides"));
        tfMappings.put("beginMask", tfAttr("begin_mask", "beginMask"));
        tfMappings.put("ellipsisMask", tfAttr("ellipsis_mask", "ellipsisMask"));
        tfMappings.put("endMask", tfAttr("end_mask", "endMask"));
        tfMappings.put("newAxisMask", tfAttr("new_axis_mask", "newAxisMask"));
        tfMappings.put("shrinkAxisMask", tfAttr("shrink_axis_mask", "shrinkAxisMask"));

        Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
        ret.put(tensorflowName(), tfMappings);
        return ret;
    }

    /** Mapping of a TF input position to a property of this op. */
    private static PropertyMapping tfInput(int position, String property) {
        return PropertyMapping.builder().tfInputPosition(position).propertyNames(new String[]{property}).build();
    }

    /** Mapping of a TF node attribute to a property of this op. */
    private static PropertyMapping tfAttr(String attrName, String property) {
        return PropertyMapping.builder().tfAttrName(attrName).propertyNames(new String[]{property}).build();
    }

    /**
     * Gradient via {@link StridedSliceBp}: uses the stored arrays in the single-input case, and
     * inputs 1..3 as begin/end/strides in the 4-input case.
     *
     * @throws IllegalStateException for any other input count (e.g. 3 inputs, no strides variable)
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> list) {
        SDVariable gradient = list.get(0);
        int numInputs = args().length;
        if (numInputs == 1) {
            return new StridedSliceBp(this.sameDiff, arg(), gradient, this.begin, this.end, this.strides,
                    this.beginMask, this.endMask, this.ellipsisMask, this.newAxisMask, this.shrinkAxisMask).outputs();
        }
        Preconditions.checkState(numInputs == 4,
                "Gradient of %s requires 1 input (static begin/end/strides) or 4 inputs (begin/end/strides as variables), got %s inputs",
                getClass().getSimpleName(), numInputs);
        return new StridedSliceBp(this.sameDiff, arg(), gradient, arg(1), arg(2), arg(3),
                this.beginMask, this.endMask, this.ellipsisMask, this.newAxisMask, this.shrinkAxisMask).outputs();
    }

    /**
     * Output type equals the first input's type. Accepts 1, 3 or 4 input types, matching
     * {@link #assertValidForExecution()}.
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        Preconditions.checkState(list != null && (list.size() == 1 || list.size() == 3 || list.size() == 4),
                "Expected 1, 3 or 4 input datatypes for %s, got %s", getClass(), list);
        return Collections.singletonList(list.get(0));
    }
}
