package org.bigml.mimir.deepnet.network.yolo;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Map;
import org.bigml.mimir.cache.BundleExtractor;
import org.bigml.mimir.cache.CNNCache;
import org.bigml.mimir.cache.TensorflowLoadedFunction;
import org.bigml.mimir.concurrent.TensorflowModel;
import org.bigml.mimir.utils.Json;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;

/**
 * Runs a TensorFlow bounding-box (YOLO-style) model and filters its raw detections with
 * {@link BoundingBoxes#nonMaxSuppression(double, double)}.
 *
 * <p>Instances built from a stream or file configure the worker sub-networks produced by
 * {@link #getWorkers(InputStream, int)}. Each worker is built from a
 * {@link TensorflowLoadedFunction} and reads its class labels from the model's
 * {@code "_classes"} setting. NOTE(review): how work is dispatched to the workers lives in
 * {@link TensorflowModel}; confirm the details there.
 */
public class TensorflowBoundingBoxPredictor extends TensorflowModel<BoundingBoxes> {
    private static final double DEFAULT_SCORE_THRESHOLD = 0.5d;
    private static final double DEFAULT_IOU_THRESHOLD = 0.5d;

    /** Score threshold passed to non-max suppression (0.0 until set). */
    private double _scoreThreshold;
    /** IoU threshold passed to non-max suppression (0.0 until set). */
    private double _iouThreshold;
    /** Labels indexed by the model's integer class output; null on stream/file-built instances. */
    private String[] _classNames;

    /**
     * Builds a predictor from a model bundle stream and applies the given thresholds to every
     * worker sub-network.
     *
     * @param inputStream    stream over the model bundle
     * @param workers        forwarded to {@link TensorflowModel}; presumably the number of
     *                       worker sub-networks (TODO confirm)
     * @param scoreThreshold score threshold each worker passes to non-max suppression
     * @param iouThreshold   IoU threshold each worker passes to non-max suppression
     */
    public TensorflowBoundingBoxPredictor(
            InputStream inputStream, int workers, double scoreThreshold, double iouThreshold) {
        super(inputStream, workers);
        this._classNames = null;
        for (Object subNetwork : this._subNetworks) {
            TensorflowBoundingBoxPredictor worker = (TensorflowBoundingBoxPredictor) subNetwork;
            worker.setScoreThreshold(scoreThreshold);
            worker.setIOUThreshold(iouThreshold);
        }
    }

    /**
     * Builds a predictor from a bundle file opened with {@link BundleExtractor#makeStream(File)}.
     *
     * @param file           model bundle file
     * @param workers        forwarded to {@link TensorflowModel}; presumably the worker count
     * @param scoreThreshold score threshold for non-max suppression
     * @param iouThreshold   IoU threshold for non-max suppression
     */
    public TensorflowBoundingBoxPredictor(
            File file, int workers, double scoreThreshold, double iouThreshold) {
        this(BundleExtractor.makeStream(file), workers, scoreThreshold, iouThreshold);
    }

    /**
     * Builds a single-worker predictor from a bundle file using the default score and IoU
     * thresholds (both 0.5).
     *
     * @param file model bundle file
     */
    public TensorflowBoundingBoxPredictor(File file) {
        this(file, 1, DEFAULT_SCORE_THRESHOLD, DEFAULT_IOU_THRESHOLD);
    }

    /**
     * Builds a predictor from whatever {@link CNNCache#getFeaturizer} resolves for the given
     * identifier.
     *
     * @param featurizer identifier passed to {@link CNNCache#getFeaturizer}
     * @throws FileNotFoundException propagated from {@link CNNCache#getFeaturizer}
     */
    public TensorflowBoundingBoxPredictor(String featurizer) throws FileNotFoundException {
        this(CNNCache.getFeaturizer(featurizer));
    }

    /**
     * Sets the score threshold this instance passes to non-max suppression.
     *
     * @param scoreThreshold new score threshold
     */
    public void setScoreThreshold(double scoreThreshold) {
        this._scoreThreshold = scoreThreshold;
    }

    /**
     * Sets the IoU threshold this instance passes to non-max suppression.
     *
     * @param iouThreshold new IoU threshold
     */
    public void setIOUThreshold(double iouThreshold) {
        this._iouThreshold = iouThreshold;
    }

    /**
     * Builds a single worker around an already loaded TensorFlow function. Class labels come
     * from the {@code "_classes"} setting. Both thresholds stay at 0.0 until set through the
     * setters.
     *
     * @param tensorflowLoadedFunction the loaded model function
     */
    public TensorflowBoundingBoxPredictor(TensorflowLoadedFunction tensorflowLoadedFunction) {
        super(tensorflowLoadedFunction);
        this._classNames = Json.getStringArray(this._settings.get("_classes"));
    }

    /**
     * Builds the model inputs from the first element of {@code list} using the inherited
     * {@code makeImageTensor}.
     */
    @Override // org.bigml.mimir.concurrent.TensorflowModel
    protected Map<String, Tensor> makeInputs(List<Object> list) {
        return makeImageTensor(list.get(0));
    }

    /**
     * Decodes the model outputs for the first batch element into {@link BoundingBoxes} and
     * applies non-max suppression with this instance's thresholds.
     *
     * <p>Outputs are identified by type and rank, not by name:
     * <ul>
     *   <li>the {@code DT_INT32} output holds class indices into the class labels;</li>
     *   <li>the rank-3 output holds the boxes;</li>
     *   <li>the remaining {@code DT_FLOAT} output holds one value per box (presumably the
     *       scores; TODO confirm against {@link BoundingBoxes}).</li>
     * </ul>
     *
     * @param map model outputs keyed by output name
     * @return the detections left after non-max suppression
     * @throws IllegalStateException if an output is unrecognised or missing, or if a class
     *                               index has no label
     */
    @Override // org.bigml.mimir.concurrent.TensorflowModel
    protected BoundingBoxes parseResult(Map<String, Tensor> map) {
        TFloat32 boxes = null;
        TInt32 classIds = null;
        TFloat32 scores = null;
        for (Map.Entry<String, Tensor> output : map.entrySet()) {
            String name = output.getKey();
            Tensor tensor = output.getValue();
            if (tensor.dataType() == DataType.DT_INT32) {
                classIds = (TInt32) tensor;
            } else if (tensor.shape().asArray().length == 3) {
                boxes = (TFloat32) tensor;
            } else if (tensor.dataType() == DataType.DT_FLOAT) {
                scores = (TFloat32) tensor;
            } else {
                throw new IllegalStateException(
                        "Output '" + name + "' with shape " + tensor.shape() + " unknown");
            }
        }
        if (boxes == null || classIds == null || scores == null) {
            // Fail with context instead of a bare NPE inside StdArrays.
            throw new IllegalStateException("Missing model output(s): boxes="
                    + (boxes != null) + ", classes=" + (classIds != null)
                    + ", scores=" + (scores != null) + "; received " + map.keySet());
        }

        // Only the first element of the batch dimension is decoded.
        float[][] boxArray = StdArrays.array3dCopyOf(boxes)[0];
        int[] classIndices = StdArrays.array2dCopyOf(classIds)[0];
        float[] scoreArray = StdArrays.array2dCopyOf(scores)[0];

        String[] labels = new String[classIndices.length];
        for (int k = 0; k < classIndices.length; k++) {
            labels[k] = classLabel(classIndices[k]);
        }
        return new BoundingBoxes(boxArray, labels, scoreArray)
                .nonMaxSuppression(this._scoreThreshold, this._iouThreshold);
    }

    /** Maps an integer class output to its label, failing clearly on unknown indices. */
    private String classLabel(int index) {
        if (this._classNames == null) {
            throw new IllegalStateException("No class names available on this predictor");
        }
        if (index < 0 || index >= this._classNames.length) {
            throw new IllegalStateException("Model predicted class index " + index
                    + " but only " + this._classNames.length + " class names are known");
        }
        return this._classNames[index];
    }

    /**
     * Creates worker predictors, each built from {@code extractor.getModel()} on one
     * {@link BundleExtractor} over {@code inputStream}. The extractor is always closed, even
     * when building a worker fails.
     *
     * @param inputStream stream over the model bundle
     * @param i           number of workers to create
     * @return the created workers
     * @throws RuntimeException wrapping any {@link IOException} raised while extracting
     */
    @Override // org.bigml.mimir.concurrent.TensorflowModel
    protected TensorflowModel<BoundingBoxes>[] getWorkers(InputStream inputStream, int i) {
        TensorflowBoundingBoxPredictor[] workers = new TensorflowBoundingBoxPredictor[i];
        try {
            BundleExtractor extractor = new BundleExtractor(inputStream);
            try {
                for (int k = 0; k < workers.length; k++) {
                    workers[k] = new TensorflowBoundingBoxPredictor(extractor.getModel());
                }
            } finally {
                extractor.close();
            }
        } catch (IOException e) {
            throw new RuntimeException("Failed to load bounding-box model workers", e);
        }
        return workers;
    }
}
