package org.deeplearning4j.nn.layers.objdetect;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.impl.LossL2;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * YOLOv2 object-detection output layer.
 *
 * <p>Input activations must have shape {@code [minibatch, B*(5+C), H, W]}, where B is the number of bounding-box
 * priors in the layer configuration and C is the number of object classes (inferred from the labels). Internally the
 * input is viewed as {@code [minibatch, B, 5+C, H, W]} with, per box: x/y at channels 0-1 (passed through a sigmoid),
 * w/h at channels 2-3 ({@code prior * exp(.)}), confidence at channel 4 (sigmoid) and class scores at channels
 * 5..5+C (softmax, applied inside the class loss).
 *
 * <p>Labels must have shape {@code [minibatch, 4+C, H, W]}: channels 0-1 and 2-3 hold the two corners (x1,y1) and
 * (x2,y2) of the ground-truth box, and channels 4.. hold the class labels. Grid cells whose class-label channels sum to
 * zero are treated as containing no object.
 *
 * <p>The layer has no trainable parameters: {@link #backpropGradient} returns an empty gradient together with the
 * gradient w.r.t. the layer input.
 */
public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer> implements Serializable, IOutputLayer {
    /** Returned as the parameter gradient; this layer has no parameters. */
    private static final Gradient EMPTY_GRADIENT = new DefaultGradient();

    protected INDArray labels;
    /** Whole-network regularization term added to the score; set by computeScore / computeScoreForExamples. */
    private double fullNetRegTerm;
    /** Score computed by the most recent call to computeBackpropGradientAndScore (non per-example mode). */
    private double score;

    /**
     * Result of {@link #calculateIOULabelPredicted}: the IOU of every predicted box with the label box of its grid
     * cell, plus the derivatives of that IOU w.r.t. the predicted centre (x/y) and size (w/h).
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static class IOURet {
        private INDArray iou;
        private INDArray dIOU_dxy;
        private INDArray dIOU_dwh;

        public IOURet(INDArray iou, INDArray dIOU_dxy, INDArray dIOU_dwh) {
            this.iou = iou;
            this.dIOU_dxy = dIOU_dxy;
            this.dIOU_dwh = dIOU_dwh;
        }

        public INDArray getIou() {
            return this.iou;
        }

        public INDArray getDIOU_dxy() {
            return this.dIOU_dxy;
        }

        public INDArray getDIOU_dwh() {
            return this.dIOU_dwh;
        }

        public void setIou(INDArray iou) {
            this.iou = iou;
        }

        public void setDIOU_dxy(INDArray dIOU_dxy) {
            this.dIOU_dxy = dIOU_dxy;
        }

        public void setDIOU_dwh(INDArray dIOU_dwh) {
            this.dIOU_dwh = dIOU_dwh;
        }

        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof IOURet)) {
                return false;
            }
            IOURet other = (IOURet) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            INDArray iou = getIou();
            INDArray otherIou = other.getIou();
            if (iou == null) {
                if (otherIou != null) {
                    return false;
                }
            } else if (!iou.equals(otherIou)) {
                return false;
            }
            INDArray dxy = getDIOU_dxy();
            INDArray otherDxy = other.getDIOU_dxy();
            if (dxy == null) {
                if (otherDxy != null) {
                    return false;
                }
            } else if (!dxy.equals(otherDxy)) {
                return false;
            }
            INDArray dwh = getDIOU_dwh();
            INDArray otherDwh = other.getDIOU_dwh();
            return dwh == null ? otherDwh == null : dwh.equals(otherDwh);
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof IOURet;
        }

        public int hashCode() {
            INDArray iou = getIou();
            int hash = (1 * 59) + (iou == null ? 43 : iou.hashCode());
            INDArray dxy = getDIOU_dxy();
            hash = (hash * 59) + (dxy == null ? 43 : dxy.hashCode());
            INDArray dwh = getDIOU_dwh();
            return (hash * 59) + (dwh == null ? 43 : dwh.hashCode());
        }

        public String toString() {
            return "Yolo2OutputLayer.IOURet(iou=" + getIou() + ", dIOU_dxy=" + getDIOU_dxy() + ", dIOU_dwh=" + getDIOU_dwh() + ")";
        }
    }

    public Yolo2OutputLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Computes the score and the gradient w.r.t. this layer's input. {@code epsilon} is not used: the error signal
     * comes from the labels set via {@link #setLabels}.
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        return new Pair<>(EMPTY_GRADIENT, computeBackpropGradientAndScore(workspaceMgr, false, false));
    }

    /**
     * Core YOLOv2 loss computation shared by scoring and backprop.
     *
     * @param workspaceMgr            workspace manager used for the returned array
     * @param scoreOnly               if true, only {@link #score} is updated and {@code null} is returned
     * @param computeScoreForExamples if true, return per-example scores of shape {@code [minibatch, 1]} (plus
     *                                {@link #fullNetRegTerm} when it is positive); {@link #score} is not updated
     * @return the per-example scores, the gradient w.r.t. the input (same shape as the input), or {@code null}
     * @throws IllegalStateException if labels are null, or labels/input have unexpected shapes
     */
    private INDArray computeBackpropGradientAndScore(LayerWorkspaceMgr workspaceMgr, boolean scoreOnly, boolean computeScoreForExamples) {
        assertInputSet(true);
        Preconditions.checkState(this.labels != null, "Cannot calculate gradients/score: labels are null");
        Preconditions.checkState(this.labels.rank() == 4, "Expected rank 4 labels array with shape [minibatch, 4+numClasses, h, w] but got rank %s labels array with shape %s", Integer.valueOf(this.labels.rank()), this.labels.shape());
        // At least one class channel must follow the 4 box-coordinate channels, otherwise C below would be <= 0.
        Preconditions.checkState(this.labels.size(1) > 4, "Invalid labels array: labels.size(1) must be > 4. labels array should be rank 4 with shape [minibatch, 4+numClasses, H, W]. Got labels array with shape %ndShape", this.labels);

        double lambdaCoord = layerConf().getLambdaCoord();
        double lambdaNoObj = layerConf().getLambdaNoObj();
        DataType dataType = this.input.dataType();

        long mb = this.input.size(0);
        long h = this.input.size(2);
        long w = this.input.size(3);
        int b = (int) layerConf().getBoundingBoxes().size(0);
        int c = ((int) this.labels.size(1)) - 4;
        // Row count once [mb, B, *, H, W] arrays are flattened to the 2d layout used by the loss functions
        long numRows = mb * b * h * w;

        INDArray labelsArr = this.labels.castTo(dataType);

        // ----- Labels -----
        INDArray classLabels = labelsArr.get(NDArrayIndex.all(), NDArrayIndex.interval(4L, labelsArr.size(1)), NDArrayIndex.all(), NDArrayIndex.all()); // [mb, C, H, W]
        // Sum over the class channels: non-zero => an object is present in that grid cell. [mb, H, W]
        INDArray maskObjectPresent = classLabels.sum(Nd4j.createUninitialized(dataType, new long[]{mb, h, w}, 'c'), 1);
        INDArray maskObjectPresentBool = maskObjectPresent.castTo(DataType.BOOL);

        INDArray labelTLXY = labelsArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all());
        INDArray labelBRXY = labelsArr.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4), NDArrayIndex.all(), NDArrayIndex.all());

        // Label centre, then only its fractional part (position within the grid cell)
        INDArray labelCenterXY = labelTLXY.add(labelBRXY).muli(0.5);
        INDArray labelCenterXYInGridBox = labelCenterXY.dup(labelCenterXY.ordering());
        labelCenterXYInGridBox.subi(Transforms.floor(labelCenterXYInGridBox, true));
        INDArray labelWHSqrt = Transforms.sqrt(labelBRXY.sub(labelTLXY), false);

        // ----- Network output activations -----
        long[] expectedInputShape = {mb, b * (5 + c), h, w};
        Preconditions.checkState(Arrays.equals(expectedInputShape, this.input.shape()), "Unable to reshape input - input array shape does not match expected shape. Expected input shape [minibatch, B*(5+C), H, W]=%s but got input of shape %ndShape. This may be due to an incorrect nOut (layer size/channels) for the last convolutional layer in the network. nOut of the last layer must be B*(5+C) where B is the number of bounding boxes, and C is the number of object classes. Expected B=%s from network configuration and C=%s from labels", expectedInputShape, this.input, Integer.valueOf(b), Integer.valueOf(c));
        INDArray input5 = this.input.dup('c').reshape('c', mb, b, 5 + c, h, w);
        INDArray classPredictionsPreSoftmax = input5.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(5, 5 + c), NDArrayIndex.all(), NDArrayIndex.all());
        INDArray predictedXYPreSigmoid = input5.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all());
        INDArray predictedXYInGridBox = Transforms.sigmoid(predictedXYPreSigmoid, true);
        // Predicted w/h = prior * exp(z)
        INDArray predictedWH = Transforms.exp(input5.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(2, 4), NDArrayIndex.all(), NDArrayIndex.all()), true);
        Broadcast.mul(predictedWH, layerConf().getBoundingBoxes().castTo(predictedWH.dataType()), predictedWH, 1, 2);
        INDArray predictedWHSqrt = Transforms.sqrt(predictedWH, true);

        // ----- IOU and "responsible box" masks -----
        IOURet iouRet = calculateIOULabelPredicted(labelTLXY, labelBRXY, predictedWH, predictedXYInGridBox, maskObjectPresent, maskObjectPresentBool);
        INDArray iou = iouRet.getIou(); // [mb, B, H, W]

        // 1 for the box with the max IOU in each cell, and only where an object is present
        INDArray objMaskBool = Nd4j.create(DataType.BOOL, iou.shape(), 'c');
        Nd4j.exec(new IsMax(iou, objMaskBool, 1));
        Nd4j.exec(new BroadcastMulOp(objMaskBool, maskObjectPresentBool, objMaskBool, 0, 2, 3));
        INDArray noObjMaskBool = Transforms.not(objMaskBool);
        INDArray objMask = objMaskBool.castTo(dataType);

        // Confidence label: IOU for responsible boxes, 0 otherwise (iou.mul -> copy; iou is reused below)
        INDArray labelConfidence = iou.mul(objMask);
        INDArray predictedConfidencePreSigmoid = input5.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(4L), NDArrayIndex.all(), NDArrayIndex.all());
        INDArray predictedConfidence = Transforms.sigmoid(predictedConfidencePreSigmoid, true);

        // ----- 2d views, one row per (example, box, y, x) -----
        INDArray objMask2d = objMask.reshape('c', numRows, 1L);
        INDArray noObjMask2d = objMask2d.rsub(1.0);
        INDArray predictedXYCenter2d = toBoxRows(predictedXYInGridBox, numRows, 2);
        INDArray labelXYCenter2d = toBoxRows(replicateOverBoxes(labelCenterXYInGridBox, mb, b, 2, h, w, dataType), numRows, 2);
        INDArray predictedWHSqrt2d = toBoxRows(predictedWHSqrt, numRows, 2);
        INDArray labelWHSqrt2d = toBoxRows(replicateOverBoxes(labelWHSqrt, mb, b, 2, h, w, dataType), numRows, 2);
        INDArray labelConfidence2d = labelConfidence.dup('c').reshape('c', numRows, 1);
        INDArray predictedConfidence2d = predictedConfidence.dup('c').reshape('c', numRows, 1);
        INDArray predictedConfidencePreSigmoid2d = predictedConfidencePreSigmoid.dup('c').reshape('c', numRows, 1);
        INDArray classPredictionsPreSoftmax2d = toBoxRows(classPredictionsPreSoftmax, numRows, c);
        INDArray classLabels2d = toBoxRows(replicateOverBoxes(classLabels, mb, b, c, h, w, dataType), numRows, c);

        LossL2 lossConfidence = new LossL2();
        ActivationIdentity identity = new ActivationIdentity();

        if (computeScoreForExamples) {
            INDArray positionLoss = layerConf().getLossPositionScale().computeScoreArray(labelXYCenter2d, predictedXYCenter2d, identity, objMask2d);
            INDArray sizeScaleLoss = layerConf().getLossPositionScale().computeScoreArray(labelWHSqrt2d, predictedWHSqrt2d, identity, objMask2d);
            INDArray confidenceLossObj = lossConfidence.computeScoreArray(labelConfidence2d, predictedConfidence2d, identity, objMask2d);
            // lambdaNoObj applied exactly once, consistent with the scalar score computed below
            INDArray confidenceLossNoObj = lossConfidence.computeScoreArray(labelConfidence2d, predictedConfidence2d, identity, noObjMask2d).muli(lambdaNoObj);
            INDArray classPredictionLoss = layerConf().getLossClassPredictions().computeScoreArray(classLabels2d, classPredictionsPreSoftmax2d, new ActivationSoftmax(), objMask2d);

            INDArray scoreForExamples = positionLoss.addi(sizeScaleLoss).muli(lambdaCoord)
                    .addi(confidenceLossObj)
                    .addi(confidenceLossNoObj)
                    .addi(classPredictionLoss)
                    .dup('c')
                    .reshape('c', mb, b * h * w)
                    .sum(true, 1);
            if (this.fullNetRegTerm > 0.0d) {
                scoreForExamples.addi(this.fullNetRegTerm);
            }
            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, scoreForExamples);
        }

        double positionScore = layerConf().getLossPositionScale().computeScore(labelXYCenter2d, predictedXYCenter2d, identity, objMask2d, false)
                + layerConf().getLossPositionScale().computeScore(labelWHSqrt2d, predictedWHSqrt2d, identity, objMask2d, false);
        double confidenceScoreObj = lossConfidence.computeScore(labelConfidence2d, predictedConfidence2d, identity, objMask2d, false);
        double confidenceScoreNoObj = lossConfidence.computeScore(labelConfidence2d, predictedConfidence2d, identity, noObjMask2d, false);
        double classScore = layerConf().getLossClassPredictions().computeScore(classLabels2d, classPredictionsPreSoftmax2d, new ActivationSoftmax(), objMask2d, false);
        this.score = lambdaCoord * positionScore + confidenceScoreObj + lambdaNoObj * confidenceScoreNoObj + classScore;
        this.score /= getInputMiniBatchSize();
        this.score += this.fullNetRegTerm;
        if (scoreOnly) {
            return null;
        }

        // ----- Gradients, written into views of a [mb, B, 5+C, H, W] view of the output array -----
        INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, dataType, this.input.shape(), 'c');
        INDArray epsOut5 = Shape.newShapeNoCopy(epsOut, new long[]{mb, b, 5 + c, h, w}, false);
        INDArray epsClassPredictions = epsOut5.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(5, 5 + c), NDArrayIndex.all(), NDArrayIndex.all());
        INDArray epsXY = epsOut5.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all());
        INDArray epsWH = epsOut5.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(2, 4), NDArrayIndex.all(), NDArrayIndex.all());
        INDArray epsConfidence = epsOut5.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(4L), NDArrayIndex.all(), NDArrayIndex.all());

        // Class predictions: back from [rows, C] to [mb, B, C, H, W]
        INDArray gradClassPredictions2d = layerConf().getLossClassPredictions().computeGradient(classLabels2d, classPredictionsPreSoftmax2d, new ActivationSoftmax(), objMask2d);
        epsClassPredictions.assign(gradClassPredictions2d.dup('c').reshape('c', mb, b, h, w, c).permute(0, 1, 4, 2, 3).dup('c'));

        // x/y centre: position loss back through the sigmoid. dup(): pre-sigmoid values are needed again below.
        INDArray gradXYCenter2d = layerConf().getLossPositionScale().computeGradient(labelXYCenter2d, predictedXYCenter2d, identity, objMask2d);
        gradXYCenter2d.muli(lambdaCoord);
        INDArray gradXYCenter5d = gradXYCenter2d.dup('c').reshape('c', mb, b, h, w, 2).permute(0, 1, 4, 2, 3);
        epsXY.assign(new ActivationSigmoid().backprop(predictedXYPreSigmoid.dup(), gradXYCenter5d).getFirst());

        // w/h: the loss is on sqrt(wh) with wh = prior * exp(z), so d sqrt(wh)/dz = 0.5 * wh / sqrt(wh)
        INDArray gradWH5d = layerConf().getLossPositionScale().computeGradient(labelWHSqrt2d, predictedWHSqrt2d, identity, objMask2d)
                .muli(0.5)
                .divi(predictedWHSqrt2d)
                .dup('c')
                .reshape('c', mb, b, h, w, 2)
                .permute(0, 1, 4, 2, 3);
        gradWH5d.muli(predictedWH);
        gradWH5d.muli(lambdaCoord);
        epsWH.assign(gradWH5d);

        // Confidence: L2 loss (object + lambdaNoObj * no-object terms) back through the sigmoid
        INDArray gradConfidence2d = lossConfidence.computeGradient(labelConfidence2d, predictedConfidence2d, identity, objMask2d)
                .addi(lossConfidence.computeGradient(labelConfidence2d, predictedConfidence2d, identity, noObjMask2d).muli(lambdaNoObj));
        INDArray gradConfidencePreSigmoid2d = new ActivationSigmoid().backprop(predictedConfidencePreSigmoid2d, gradConfidence2d).getFirst();
        epsConfidence.assign(gradConfidencePreSigmoid2d.dup('c').reshape('c', mb, b, h, w));

        // The confidence label is IOU(predicted, label), which also depends on the predicted x/y/w/h:
        // dL/dIOU = 2 * (IOU - predictedConfidence) * (lambdaNoObj * noObjMask + objMask). Note: modifies iou in place.
        INDArray gradIou = iou.subi(predictedConfidence).muli(2.0)
                .muli(noObjMaskBool.castTo(dataType).muli(lambdaNoObj).addi(objMask));
        INDArray dIouDxy = iouRet.getDIOU_dxy();
        INDArray dIouDwh = iouRet.getDIOU_dwh();
        INDArray gradXYFromIou = Nd4j.createUninitialized(dIouDxy.dataType(), dIouDxy.shape(), dIouDxy.ordering());
        Broadcast.mul(dIouDxy, gradIou, gradXYFromIou, 0, 1, 3, 4);
        INDArray gradWHFromIou = Nd4j.createUninitialized(dIouDwh.dataType(), dIouDwh.shape(), dIouDwh.ordering());
        Broadcast.mul(dIouDwh, gradIou, gradWHFromIou, 0, 1, 3, 4);

        // Through the activations: d(prior * exp(z))/dz = wh; sigmoid for x/y (last use of the pre-sigmoid values)
        INDArray gradWHPreExpFromIou = gradWHFromIou.muli(predictedWH);
        INDArray gradXYPreSigmoidFromIou = new ActivationSigmoid().backprop(predictedXYPreSigmoid, gradXYFromIou).getFirst();

        // Only the box responsible for each object receives this contribution
        Broadcast.mul(gradWHPreExpFromIou, objMask, gradWHPreExpFromIou, 0, 1, 3, 4);
        Broadcast.mul(gradXYPreSigmoidFromIou, objMask, gradXYPreSigmoidFromIou, 0, 1, 3, 4);

        epsWH.addi(gradWHPreExpFromIou);
        epsXY.addi(gradXYPreSigmoidFromIou);
        return epsOut;
    }

    /**
     * Copies a per-cell array of shape [mb, channels, H, W] into each of the B box slots of a new c-order array of
     * shape [mb, B, channels, H, W].
     */
    private static INDArray replicateOverBoxes(INDArray perCell, long mb, int b, long channels, long h, long w, DataType dataType) {
        INDArray perBox = Nd4j.createUninitialized(dataType, new long[]{mb, b, channels, h, w}, 'c');
        for (int box = 0; box < b; box++) {
            perBox.get(NDArrayIndex.all(), NDArrayIndex.point(box), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()).assign(perCell);
        }
        return perBox;
    }

    /**
     * Moves axis 2 of a [mb, B, k, H, W] array to the end and flattens it to a c-order [rows, k] matrix, i.e. one
     * row per (example, box, y, x).
     */
    private static INDArray toBoxRows(INDArray arr5d, long rows, long cols) {
        return arr5d.permute(0, 1, 3, 4, 2).dup('c').reshape('c', rows, cols);
    }

    /** Returns the view {@code arr5d[:, :, channel, :, :]}. */
    private static INDArray boxChannel(INDArray arr5d, long channel) {
        return arr5d.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(channel), NDArrayIndex.all(), NDArrayIndex.all());
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        return YoloUtils.activate(layerConf().getBoundingBoxes(), this.input, workspaceMgr);
    }

    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    // NOTE(review): decompiler-generated name; cloning is not supported.
    public Layer m5783clone() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public boolean needsLabels() {
        return true;
    }

    /**
     * Computes and stores the average (per-example) score plus {@code fullNetRegTerm}, without computing gradients.
     */
    @Override
    public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.fullNetRegTerm = fullNetRegTerm;
        computeBackpropGradientAndScore(workspaceMgr, true, false);
        return score();
    }

    @Override
    public double score() {
        return this.score;
    }

    /**
     * Computes, for every predicted box, its intersection-over-union with the label box of the same grid cell, and the
     * derivatives of that IOU w.r.t. the predicted centre and size.
     *
     * @param labelTL               label corner (x1, y1), [mb, 2, H, W]
     * @param labelBR               label corner (x2, y2), [mb, 2, H, W]
     * @param predictedWH           predicted width/height, [mb, B, 2, H, W]
     * @param predictedXYinGridBox  predicted centre within its grid cell, [mb, B, 2, H, W]
     * @param objectPresentMask     object-present mask (sum of class labels), [mb, H, W]
     * @param objectPresentMaskBool boolean version of {@code objectPresentMask}
     * @return IOU [mb, B, H, W] (0 where no object is present or the boxes do not intersect), and dIOU/dxy and
     *         dIOU/dwh, each [mb, B, 2, H, W]
     */
    private static IOURet calculateIOULabelPredicted(INDArray labelTL, INDArray labelBR, INDArray predictedWH, INDArray predictedXYinGridBox, INDArray objectPresentMask, INDArray objectPresentMaskBool) {
        long mb = labelTL.size(0);
        long h = labelTL.size(2);
        long w = labelTL.size(3);
        long b = predictedWH.size(1);
        DataType dataType = predictedWH.dataType();

        INDArray labelWH = labelBR.sub(labelTL); // [mb, 2, H, W]

        // Grid-cell offsets, [2, H, W]: grid[0] varies along W (x), grid[1] along H (y)
        INDArray linspaceX = Nd4j.linspace(0L, w - 1, w, dataType);
        INDArray linspaceY = Nd4j.linspace(0L, h - 1, h, dataType);
        INDArray grid = Nd4j.createUninitialized(dataType, new long[]{2, h, w}, 'c');
        INDArray gridX = grid.get(NDArrayIndex.point(0L), NDArrayIndex.all(), NDArrayIndex.all());
        INDArray gridY = grid.get(NDArrayIndex.point(1L), NDArrayIndex.all(), NDArrayIndex.all());
        Broadcast.copy(gridX, linspaceX, gridX, 1);
        Broadcast.copy(gridY, linspaceY, gridY, 0);

        // Predicted centre in grid units = position within cell + cell offset
        INDArray predictedXY = predictedXYinGridBox.ulike();
        Broadcast.add(predictedXYinGridBox, grid, predictedXY, 2, 3, 4);

        INDArray halfWH = predictedWH.mul(0.5);
        INDArray predictedTL = halfWH.rsub(predictedXY); // xy - wh/2
        INDArray predictedBR = halfWH.add(predictedXY);  // xy + wh/2

        INDArray maxTL = predictedTL.ulike();
        Broadcast.max(predictedTL, labelTL, maxTL, 0, 2, 3, 4);
        INDArray minBR = predictedBR.ulike();
        Broadcast.min(predictedBR, labelBR, minBR, 0, 2, 3, 4);

        // Overlap extent per axis, [mb, B, 2, H, W]; <= 0 on an axis where the boxes do not overlap
        INDArray overlapWH = minBR.sub(maxTL);
        INDArray intersectionArea = overlapWH.prod(2); // [mb, B, H, W]
        Broadcast.mul(intersectionArea, objectPresentMask, intersectionArea, 0, 2, 3);

        // No intersection if predicted BR <= label TL, or predicted TL >= label BR, on either axis
        INDArray noIntersectBR = Nd4j.createUninitialized(DataType.BOOL, maxTL.shape(), maxTL.ordering());
        INDArray noIntersectTL = Nd4j.createUninitialized(DataType.BOOL, maxTL.shape(), maxTL.ordering());
        Broadcast.lte(predictedBR, labelTL, noIntersectBR, 0, 2, 3, 4);
        Broadcast.gte(predictedTL, labelBR, noIntersectTL, 0, 2, 3, 4);
        INDArray intersectMaskBool = Transforms.not(Transforms.or(
                Transforms.or(boxChannel(noIntersectBR, 0L), boxChannel(noIntersectBR, 1L)),
                Transforms.or(boxChannel(noIntersectTL, 0L), boxChannel(noIntersectTL, 1L))));
        Broadcast.mul(intersectMaskBool, objectPresentMaskBool, intersectMaskBool, 0, 2, 3);
        INDArray intersectMask = intersectMaskBool.castTo(dataType);
        intersectionArea.muli(intersectMask);

        // Union = area(predicted) + area(label) - intersection
        INDArray areaPredicted = predictedWH.prod(2); // [mb, B, H, W]
        Broadcast.mul(areaPredicted, objectPresentMask, areaPredicted, 0, 2, 3);
        INDArray areaLabel = labelWH.prod(1); // [mb, H, W]
        INDArray unionArea = Broadcast.add(areaPredicted, areaLabel, areaPredicted.dup(), 0, 2, 3);
        unionArea.subi(intersectionArea);
        unionArea.muli(intersectMask);

        INDArray iou = intersectionArea.div(unionArea);
        BooleanIndexing.replaceWhere(iou, 0.0, Conditions.isNan()); // 0/0 where masked out
        Broadcast.mul(iou, objectPresentMask, iou, 0, 2, 3);

        // Indicators: 1 where the predicted edge is the one bounding the intersection
        INDArray predTLIsMaxBool = Nd4j.createUninitialized(DataType.BOOL, maxTL.shape(), maxTL.ordering());
        Broadcast.gt(predictedTL, labelTL, predTLIsMaxBool, 0, 2, 3, 4);
        INDArray predTLIsMax = predTLIsMaxBool.castTo(dataType);
        INDArray predBRIsMinBool = Nd4j.createUninitialized(DataType.BOOL, maxTL.shape(), maxTL.ordering());
        Broadcast.lt(predictedBR, labelBR, predBRIsMinBool, 0, 2, 3, 4);
        INDArray predBRIsMin = predBRIsMinBool.castTo(dataType);

        // dI/dx = (1[BR_p < BR_l] - 1[TL_p > TL_l]) * overlapY;  dI/dw = 0.5 * (1[BR_p < BR_l] + 1[TL_p > TL_l]) * overlapY
        // (and symmetrically for y/h with overlapX). dIdxy must be computed before predBRIsMin is modified in place.
        INDArray dIdxy = predBRIsMin.sub(predTLIsMax);
        INDArray dIdwh = predBRIsMin.addi(predTLIsMax).muli(0.5);
        boxChannel(dIdxy, 0L).muli(boxChannel(overlapWH, 1L));
        boxChannel(dIdxy, 1L).muli(boxChannel(overlapWH, 0L));
        boxChannel(dIdwh, 0L).muli(boxChannel(overlapWH, 1L));
        boxChannel(dIdwh, 1L).muli(boxChannel(overlapWH, 0L));

        // IOU = I / U with U = A_p + A_l - I  =>  dIOU/dI = (U + I) / U^2
        INDArray unionPlusIntersection = unionArea.add(intersectionArea);
        INDArray unionSquared = unionArea.mul(unionArea);
        INDArray dIouDi = unionPlusIntersection.div(unionSquared);
        BooleanIndexing.replaceWhere(dIouDi, 0.0, Conditions.isNan());

        INDArray dIouDxy = Nd4j.createUninitialized(dataType, new long[]{mb, b, 2, h, w}, 'c');
        Broadcast.mul(dIdxy, dIouDi, dIouDxy, 0, 1, 3, 4);

        // Predicted (h, w): the w/h channels swapped, since dA_p/dw = h and dA_p/dh = w
        INDArray predictedHW = Nd4j.createUninitialized(dataType, new long[]{mb, b, 2, h, w}, predictedWH.ordering());
        boxChannel(predictedHW, 0L).assign(boxChannel(predictedWH, 1L));
        boxChannel(predictedHW, 1L).assign(boxChannel(predictedWH, 0L));
        INDArray intersectionTimesHW = predictedHW.ulike();
        Broadcast.mul(predictedHW, intersectionArea, intersectionTimesHW, 0, 1, 3, 4);

        // dIOU/dwh = (dI/dwh * (U + I) - I * dA_p/dwh) / U^2
        INDArray dIouDwh = Nd4j.createUninitialized(predictedHW.dataType(), new long[]{mb, b, 2, h, w}, 'c');
        Broadcast.mul(dIdwh, unionPlusIntersection, dIouDwh, 0, 1, 3, 4);
        dIouDwh.subi(intersectionTimesHW);
        Broadcast.div(dIouDwh, unionSquared, dIouDwh, 0, 1, 3, 4);
        BooleanIndexing.replaceWhere(dIouDwh, 0.0, Conditions.isNan());

        return new IOURet(iou, dIouDxy, dIouDwh);
    }

    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(gradient(), Double.valueOf(score()));
    }

    /** Returns per-example scores, shape [minibatch, 1], each including {@code fullNetRegTerm} when positive. */
    @Override
    public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr workspaceMgr) {
        this.fullNetRegTerm = fullNetRegTerm;
        return computeBackpropGradientAndScore(workspaceMgr, false, true);
    }

    @Override
    public double f1Score(DataSet data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int numLabels() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(DataSetIterator iter) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int[] predict(INDArray examples) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public List<String> predict(DataSet dataSet) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(DataSet data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, int[] labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    @Override
    public void clearNoiseWeightParams() {
    }

    /** Delegates to {@code YoloUtils.getPredictedObjects} with this layer's bounding-box priors. */
    public List<DetectedObject> getPredictedObjects(INDArray networkOutput, double threshold) {
        return YoloUtils.getPredictedObjects(layerConf().getBoundingBoxes(), networkOutput, threshold, 0.0d);
    }

    /**
     * Returns the [H, W] confidence map of one bounding box for one example.
     *
     * <p>Uses the per-box channel layout {@code [mb, B*(5+C), H, W]} that this layer validates and reshapes its input
     * with, so the confidence of box {@code bbNumber} is at channel {@code bbNumber*(5+C) + 4}. (The previous index
     * {@code 4 + bbNumber*5} only matched this for box 0.) Assumes {@code networkOutput} has the same channel layout
     * as the layer input — TODO confirm for callers passing other arrays.
     */
    public INDArray getConfidenceMatrix(INDArray networkOutput, int example, int bbNumber) {
        long numBoxes = layerConf().getBoundingBoxes().size(0);
        long channelsPerBox = networkOutput.size(1) / numBoxes; // 5 + C
        return networkOutput.get(NDArrayIndex.point(example), NDArrayIndex.point(bbNumber * channelsPerBox + 4), NDArrayIndex.all(), NDArrayIndex.all());
    }

    /**
     * Returns channel {@code 5*B + classNumber} of the given example.
     *
     * <p>NOTE(review): with the per-box layout {@code [mb, B*(5+C), H, W]} used by this layer, that index is the class
     * channel of box 0 only when B == 1. Kept unchanged for compatibility.
     *
     * @deprecated use {@link #getProbabilityMatrix(INDArray, int, int, int)}, which takes the box index
     */
    @Deprecated
    public INDArray getProbabilityMatrix(INDArray networkOutput, int example, int classNumber) {
        return networkOutput.get(NDArrayIndex.point(example), NDArrayIndex.point((5 * layerConf().getBoundingBoxes().size(0)) + classNumber), NDArrayIndex.all(), NDArrayIndex.all());
    }

    /**
     * Returns the [H, W] map of class {@code classNumber} for box {@code bbNumber} of one example, i.e. channel
     * {@code bbNumber*(5+C) + 5 + classNumber} of a {@code [mb, B*(5+C), H, W]} array. Assumes
     * {@code networkOutput} has the same channel layout as the layer input — TODO confirm for other arrays.
     */
    public INDArray getProbabilityMatrix(INDArray networkOutput, int example, int bbNumber, int classNumber) {
        long numBoxes = layerConf().getBoundingBoxes().size(0);
        long channelsPerBox = networkOutput.size(1) / numBoxes; // 5 + C
        return networkOutput.get(NDArrayIndex.point(example), NDArrayIndex.point(bbNumber * channelsPerBox + 5 + classNumber), NDArrayIndex.all(), NDArrayIndex.all());
    }

    @Override
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    @Override
    public INDArray getLabels() {
        return this.labels;
    }
}
