package org.canova.image.loader;

import com.github.jaiimageio.impl.plugins.tiff.TIFFImageReaderSpi;
import com.github.jaiimageio.impl.plugins.tiff.TIFFImageWriterSpi;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.awt.image.ImageObserver;
import java.awt.image.Raster;
import java.awt.image.WritableRaster;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import javax.imageio.ImageIO;
import javax.imageio.spi.IIORegistry;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/canova/image/loader/ImageLoader.class */
/**
 * Loads images through {@link ImageIO} (with the JAI TIFF reader/writer registered in the static
 * initializer) and converts them into ND4J arrays or plain Java {@code int} arrays.
 *
 * <p>When both {@code width} and {@code height} are positive, every image is first rescaled with
 * {@link Image#SCALE_SMOOTH}. Pixel arrays produced here are indexed by image x coordinate first
 * and y coordinate second ({@code [x][y]}).
 *
 * <p>NOTE(review): the rescale passes {@code height} as the target pixel width and {@code width} as
 * the target pixel height ({@code getScaledInstance(height, width, ...)}). Combined with the
 * {@code [x][y]} indexing this yields arrays shaped {@code [height][width]}. The argument order is
 * kept unchanged because existing callers depend on those shapes.
 */
public class ImageLoader implements Serializable {
    /** Target width for rescaling; rescaling only happens when both width and height are positive. */
    private int width = -1;
    /** Target height for rescaling; rescaling only happens when both width and height are positive. */
    private int height = -1;
    /** Channel count; {@code 3} selects the RGB code paths in several methods, {@code -1} means unset. */
    private int channels = -1;

    /** Creates a loader that keeps images at their native size and has no channel count set. */
    public ImageLoader() {
    }

    /**
     * Creates a loader that rescales images (when both values are positive).
     *
     * @param width target width
     * @param height target height
     */
    public ImageLoader(int width, int height) {
        this.width = width;
        this.height = height;
    }

    /**
     * Creates a loader that rescales images and records a channel count.
     *
     * @param width target width
     * @param height target height
     * @param channels channel count ({@code 3} selects the RGB code paths)
     */
    public ImageLoader(int width, int height, int channels) {
        this(width, height);
        this.channels = channels;
    }

    /**
     * Loads {@code file} as a flat row vector: the raveled RGB tensor when {@code channels == 3},
     * otherwise the flattened band-0 samples.
     *
     * @param file image file to decode
     * @return the image as a row vector
     * @throws Exception if the image cannot be read
     */
    public INDArray asRowVector(File file) throws Exception {
        if (this.channels == 3) {
            return toRaveledTensor(file);
        }
        return ArrayUtil.toNDArray(flattenedImageFromFile(file));
    }

    /**
     * Decodes {@code file} and returns its RGB tensor raveled into a row vector.
     *
     * @throws RuntimeException wrapping any {@link IOException} from opening, reading or closing the file
     */
    public INDArray toRaveledTensor(File file) {
        // try-with-resources: the stream is closed even if decoding throws.
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            return toRaveledTensor(in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Decodes the stream and returns its RGB tensor (see {@link #toRgb(InputStream)}) raveled. */
    public INDArray toRaveledTensor(InputStream inputStream) {
        return toRgb(inputStream).ravel();
    }

    /**
     * Decodes {@code file} into an RGB tensor (see {@link #toRgb(InputStream)}).
     *
     * @throws RuntimeException wrapping any {@link IOException} from opening or closing the file
     */
    public INDArray toRgb(File file) {
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            return toRgb(in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes a 3-channel array into {@code bufferedImage} as opaque RGB pixels.
     *
     * <p>Red, green and blue are read from {@code slice(0)}, {@code slice(1)} and {@code slice(2)} at
     * index {@code [x, y]}. Only the region covered by both the image and the array
     * ({@code size(-2)} by {@code size(-1)}) is written; other pixels are left untouched. Each channel
     * value is truncated to its low 8 bits.
     *
     * @param iNDArray source array of rank 3 or more
     * @param bufferedImage destination image, modified in place
     * @throws IllegalArgumentException if {@code iNDArray} has rank below 3
     */
    public void toBufferedImageRGB(INDArray iNDArray, BufferedImage bufferedImage) {
        if (iNDArray.rank() < 3) {
            throw new IllegalArgumentException("Arr must be 3d");
        }
        // Write into the caller's image directly. Rescaling would create a new image object, and
        // writes to that copy would never be visible to the caller.
        int w = Math.min(bufferedImage.getWidth(), iNDArray.size(-2));
        int h = Math.min(bufferedImage.getHeight(), iNDArray.size(-1));
        INDArray red = iNDArray.slice(0);
        INDArray green = iNDArray.slice(1);
        INDArray blue = iNDArray.slice(2);
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                int[] idx = new int[] {x, y};
                // setRGB expects packed ARGB: 0xFF alpha = fully opaque.
                int argb = (0xFF << 24)
                        | ((red.getInt(idx) & 0xFF) << 16)
                        | ((green.getInt(idx) & 0xFF) << 8)
                        | (blue.getInt(idx) & 0xFF);
                bufferedImage.setRGB(x, y, argb);
            }
        }
    }

    /**
     * Decodes the stream into a {@code [3, imageWidth, imageHeight]} tensor of red, green and blue
     * values (0-255), indexed {@code [channel, x, y]}, after optional rescaling.
     *
     * @throws RuntimeException ("Unable to load image") if decoding fails or the format is unsupported
     */
    public INDArray toRgb(InputStream inputStream) {
        try {
            return rgbTensor(scaleIfNeeded(requireDecoded(ImageIO.read(inputStream))));
        } catch (IOException e) {
            throw new RuntimeException("Unable to load image", e);
        }
    }

    /** Returns the red, green and blue components (0-255) of pixel {@code (i, i2)}. */
    private int[] getPixelData(BufferedImage bufferedImage, int i, int i2) {
        int rgb = bufferedImage.getRGB(i, i2);
        return new int[] {(rgb >> 16) & 255, (rgb >> 8) & 255, rgb & 255};
    }

    /**
     * Decodes the stream into a matrix. Returns the RGB tensor when {@code channels == 3}; otherwise
     * a {@code [imageWidth, imageHeight]} matrix of band-0 samples indexed {@code [x, y]}.
     *
     * @throws RuntimeException ("Unable to load image") if decoding fails or the format is unsupported
     */
    public INDArray asMatrix(InputStream inputStream) {
        if (this.channels == 3) {
            return toRgb(inputStream);
        }
        try {
            BufferedImage image = scaleIfNeeded(requireDecoded(ImageIO.read(inputStream)));
            int[][] samples = firstBandSamples(image);
            INDArray matrix = Nd4j.create(image.getWidth(), image.getHeight());
            for (int x = 0; x < samples.length; x++) {
                for (int y = 0; y < samples[x].length; y++) {
                    matrix.putScalar(new int[] {x, y}, samples[x][y]);
                }
            }
            return matrix;
        } catch (IOException e) {
            throw new RuntimeException("Unable to load image", e);
        }
    }

    /** Decodes the stream via {@link #asMatrix(InputStream)} and ravels the result. */
    public INDArray asRowVector(InputStream inputStream) {
        return asMatrix(inputStream).ravel();
    }

    /**
     * Allocates a new array of shape {@code [i, i2, asMatrix(file).columns()]}. No pixel data is
     * copied into it; the file is decoded only to obtain the column count.
     *
     * @throws RuntimeException wrapping any failure while loading {@code file}
     */
    public INDArray asImageMiniBatches(File file, int i, int i2) {
        try {
            return Nd4j.create(new int[] {i, i2, asMatrix(file).columns()});
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Converts the band-0 samples of {@code file} (see {@link #fromFile(File)}) to an INDArray. */
    public INDArray asMatrix(File file) throws IOException {
        return ArrayUtil.toNDArray(fromFile(file));
    }

    /** Returns the band-0 samples of {@code file} flattened with {@link ArrayUtil#flatten}. */
    public int[] flattenedImageFromFile(File file) throws Exception {
        return ArrayUtil.flatten(fromFile(file));
    }

    /**
     * Decodes {@code file} (after optional rescaling) and returns band 0 of its raster as an
     * {@code [x][y]} array.
     *
     * @throws IOException if the file cannot be read or no registered reader supports its format
     */
    public int[][] fromFile(File file) throws IOException {
        return firstBandSamples(scaleIfNeeded(requireDecoded(ImageIO.read(file))));
    }

    /**
     * Decodes {@code file} (after optional rescaling) into an {@code [x][y][channel]} array holding
     * red, green and blue in channels 0, 1 and 2. The channel dimension is
     * {@code max(channels, 3)}; any channels beyond the third are left as zero.
     *
     * @throws IOException if the file cannot be read or no registered reader supports its format
     */
    public int[][][] fromFileMultipleChannels(File file) throws IOException {
        BufferedImage image = scaleIfNeeded(requireDecoded(ImageIO.read(file)));
        int w = image.getWidth();
        int h = image.getHeight();
        // Always leave room for the three RGB components, even when channels is unset or < 3.
        int depth = Math.max(this.channels, 3);
        int[][][] pixels = new int[w][h][depth];
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                // Same R, G, B order as toRgb.
                int[] rgb = getPixelData(image, x, y);
                System.arraycopy(rgb, 0, pixels[x][y], 0, rgb.length);
            }
        }
        return pixels;
    }

    /**
     * Creates a {@code TYPE_INT_ARGB} image that is {@code rows()} pixels wide and {@code columns()}
     * pixels tall. It fills the image with the array's elements truncated to {@code int}, using
     * {@link WritableRaster#setDataElements}.
     *
     * <p>NOTE(review): for this image type, each element is written as a packed ARGB pixel, not a
     * gray level. Confirm this is the intended encoding.
     */
    public static BufferedImage toImage(INDArray iNDArray) {
        BufferedImage bufferedImage = new BufferedImage(iNDArray.rows(), iNDArray.columns(), 2);
        WritableRaster raster = bufferedImage.getRaster();
        int[] iArr = new int[iNDArray.length()];
        for (int i = 0; i < iArr.length; i++) {
            iArr[i] = (int) iNDArray.getDouble(i);
        }
        raster.setDataElements(0, 0, iNDArray.rows(), iNDArray.columns(), iArr);
        return bufferedImage;
    }

    /** Returns every element of {@code iNDArray}, rounded to the nearest integer. */
    private static int[] rasterData(INDArray iNDArray) {
        int[] data = new int[iNDArray.length()];
        for (int i = 0; i < data.length; i++) {
            // getDouble works regardless of the array's data type; element() is not guaranteed to be a Double.
            data[i] = (int) Math.round(iNDArray.getDouble(i));
        }
        return data;
    }

    /**
     * Returns {@code image} itself if it is already a {@link BufferedImage}. Otherwise it draws the
     * image onto a new {@code TYPE_INT_ARGB} image of the same size.
     */
    public static BufferedImage toBufferedImage(Image image) {
        if (image instanceof BufferedImage) {
            return (BufferedImage) image;
        }
        BufferedImage bufferedImage = new BufferedImage(image.getWidth((ImageObserver) null), image.getHeight((ImageObserver) null), 2);
        Graphics2D createGraphics = bufferedImage.createGraphics();
        createGraphics.drawImage(image, 0, 0, (ImageObserver) null);
        createGraphics.dispose();
        return bufferedImage;
    }

    /** Returns the band-0 samples of {@code bufferedImage} (after optional rescaling), flattened. */
    public INDArray asRowVector(BufferedImage bufferedImage) {
        return ArrayUtil.toNDArray(ArrayUtil.flatten(firstBandSamples(scaleIfNeeded(bufferedImage))));
    }

    /**
     * Returns the {@code [3, imageWidth, imageHeight]} RGB tensor of {@code bufferedImage} (after
     * optional rescaling), raveled. The tensor layout is the same as {@link #toRgb(InputStream)}.
     *
     * @throws RuntimeException ("Unable to load image") wrapping any failure
     */
    public INDArray toRaveledTensor(BufferedImage bufferedImage) {
        try {
            return rgbTensor(scaleIfNeeded(bufferedImage)).ravel();
        } catch (Exception e) {
            throw new RuntimeException("Unable to load image", e);
        }
    }

    /**
     * Rescales {@code image} when both target dimensions are positive; otherwise returns it unchanged.
     * The (height, width) argument order is intentional for backward compatibility (see class note).
     */
    private BufferedImage scaleIfNeeded(BufferedImage image) {
        if (this.height > 0 && this.width > 0) {
            return toBufferedImage(image.getScaledInstance(this.height, this.width, Image.SCALE_SMOOTH));
        }
        return image;
    }

    /** Returns band 0 of the image raster as an {@code [x][y]} array. */
    private static int[][] firstBandSamples(BufferedImage image) {
        Raster raster = image.getRaster();
        int w = raster.getWidth();
        int h = raster.getHeight();
        int[][] samples = new int[w][h];
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                samples[x][y] = raster.getSample(x, y, 0);
            }
        }
        return samples;
    }

    /**
     * Builds a {@code [3, imageWidth, imageHeight]} tensor of red, green and blue values (0-255),
     * indexed {@code [channel, x, y]}. It is sized from the image itself so that it also works when
     * no target size is set.
     */
    private INDArray rgbTensor(BufferedImage image) {
        int w = image.getWidth();
        int h = image.getHeight();
        INDArray tensor = Nd4j.create(new int[] {3, w, h});
        for (int x = 0; x < w; x++) {
            for (int y = 0; y < h; y++) {
                int[] pixel = getPixelData(image, x, y);
                for (int c = 0; c < pixel.length; c++) {
                    tensor.putScalar(new int[] {c, x, y}, pixel[c]);
                }
            }
        }
        return tensor;
    }

    /**
     * Returns {@code image} unchanged if it is non-null. {@link ImageIO#read} returns {@code null}
     * when no registered reader supports the input; that case is reported as an {@link IOException}
     * instead of surfacing later as a NullPointerException.
     */
    private static BufferedImage requireDecoded(BufferedImage image) throws IOException {
        if (image == null) {
            throw new IOException("Unsupported or unrecognized image format");
        }
        return image;
    }

    static {
        // Make TIFF decoding/encoding available to ImageIO for every loader instance.
        IIORegistry defaultInstance = IIORegistry.getDefaultInstance();
        defaultInstance.registerServiceProvider(new TIFFImageWriterSpi());
        defaultInstance.registerServiceProvider(new TIFFImageReaderSpi());
    }
}
