package ai.djl.modality.cv.output;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.util.JsonSerializable;
import ai.djl.util.JsonUtils;
import ai.djl.util.RandomUtils;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import java.lang.reflect.Type;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.List;

/* loaded from: input_file:ai/djl/modality/cv/output/CategoryMask.class */
/**
 * A per-pixel category assignment produced by a segmentation model.
 *
 * <p>Each entry {@code mask[y][x]} is an index into {@link #getClasses()}. It is used directly as
 * an index into a per-class color table when drawing overlays. Class index {@code 0} is treated as
 * the background by {@link #getBackgroundImage(Image)} and by the random-color overlay.
 *
 * <p>When serialized to JSON (see {@link SegmentationSerializer}), only the mask array is emitted;
 * the class names are not included.
 */
public class CategoryMask implements JsonSerializable {

    private static final long serialVersionUID = 1;

    /** Opaque black (alpha 0xFF, RGB 0); the default background color for overlays. */
    private static final int COLOR_BLACK = 0xFF000000;

    /** Largest value representable in the 8-bit alpha channel of a packed color. */
    private static final int MAX_ALPHA = 0xFF;

    private static final Gson GSON =
            JsonUtils.builder()
                    .registerTypeAdapter(CategoryMask.class, new SegmentationSerializer())
                    .create();

    private final List<String> classes;
    private final int[][] mask;

    /** Gson serializer that writes a {@link CategoryMask} as just its 2-D mask array. */
    public static final class SegmentationSerializer implements JsonSerializer<CategoryMask> {

        /** {@inheritDoc} */
        @Override
        public JsonElement serialize(
                CategoryMask categoryMask, Type type, JsonSerializationContext context) {
            return context.serialize(categoryMask.getMask());
        }
    }

    /**
     * Creates a category mask.
     *
     * @param classes the class names; a mask value {@code k} refers to {@code classes.get(k)}
     * @param mask the per-pixel class indices, indexed {@code mask[row][column]}; the list and
     *     array are stored by reference, not copied
     */
    public CategoryMask(List<String> classes, int[][] mask) {
        this.classes = classes;
        this.mask = mask;
    }

    /**
     * Returns the class names this mask's indices refer to.
     *
     * @return the class list passed to the constructor (not a copy)
     */
    public List<String> getClasses() {
        return classes;
    }

    /**
     * Returns the per-pixel class indices.
     *
     * @return the mask array passed to the constructor (not a copy)
     */
    public int[][] getMask() {
        return mask;
    }

    /**
     * Returns this mask serialized as UTF-8 JSON bytes.
     *
     * @return a buffer wrapping the bytes of {@link #toJson()}
     */
    @Override
    public ByteBuffer toByteBuffer() {
        return ByteBuffer.wrap(toJson().getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Serializes the mask array to JSON.
     *
     * @return the JSON text of the mask, followed by a trailing newline
     */
    @Override
    public String toJson() {
        return GSON.toJson(this) + '\n';
    }

    /**
     * Builds a mask image from the raw class-index array by delegating to {@link
     * Image#getMask(int[][])}.
     *
     * @param image the image whose {@code getMask} implementation is used
     * @return the image produced by {@code image.getMask(mask)}
     */
    public Image getMaskImage(Image image) {
        return image.getMask(mask);
    }

    /**
     * Builds a mask image for a single class: pixels of {@code classId} become {@code 1}, all
     * others {@code 0}; the binary array is then passed to {@link Image#getMask(int[][])}.
     *
     * @param image the image whose {@code getMask} implementation is used
     * @param classId the class index to isolate
     * @return the image produced by {@code image.getMask(binaryMask)}
     */
    public Image getMaskImage(Image image, int classId) {
        int height = mask.length;
        // Width is taken from the first row; the mask is assumed to be rectangular and non-empty.
        int width = mask[0].length;
        int[][] binary = new int[height][width];
        for (int y = 0; y < height; y++) {
            int[] src = mask[y];
            int[] dst = binary[y];
            for (int x = 0; x < width; x++) {
                dst[x] = src[x] == classId ? 1 : 0;
            }
        }
        return image.getMask(binary);
    }

    /**
     * Builds the mask image for the background class (index {@code 0}).
     *
     * @param image the image whose {@code getMask} implementation is used
     * @return the binary mask image of class {@code 0}
     */
    public Image getBackgroundImage(Image image) {
        return getMaskImage(image, 0);
    }

    /**
     * Draws a random-color overlay of every non-background class onto {@code image}, painting the
     * background (class {@code 0}) opaque black.
     *
     * @param image the image to draw on (modified in place)
     * @param opacity overlay alpha for the non-background classes; values outside {@code [0, 255]}
     *     are clamped
     */
    public void drawMask(Image image, int opacity) {
        drawMask(image, opacity, COLOR_BLACK);
    }

    /**
     * Draws a random-color overlay of every non-background class onto {@code image}.
     *
     * @param image the image to draw on (modified in place)
     * @param opacity overlay alpha for the non-background classes; values outside {@code [0, 255]}
     *     are clamped
     * @param background packed color (alpha in the top byte) used as-is for class {@code 0}; use a
     *     transparent color to leave the background untouched
     */
    public void drawMask(Image image, int opacity, int background) {
        int[] palette = generateColors(background, opacity);
        // NOTE(review): the `true` flag is passed through to Image.drawImage; presumably it
        // requests alpha blending — confirm against the Image implementation.
        image.drawImage(getColorOverlay(palette), true);
    }

    /**
     * Highlights a single class with a fixed color; every other class is left fully transparent.
     *
     * @param image the image to draw on (modified in place)
     * @param classId the class index to highlight
     * @param color the RGB color; its top (alpha) byte is ignored and replaced by {@code opacity}
     * @param opacity overlay alpha; values outside {@code [0, 255]} are clamped
     */
    public void drawMask(Image image, int classId, int color, int opacity) {
        // Unset entries stay 0, i.e. fully transparent, so only classId shows up.
        int[] palette = new int[classes.size()];
        palette[classId] = (color & 0x00FFFFFF) | (toAlpha(opacity) << 24);
        image.drawImage(getColorOverlay(palette), true);
    }

    /**
     * Renders the mask as a packed-color image by looking up each pixel's class in a palette.
     *
     * @param palette packed color per class index; must cover every value present in the mask
     * @return a new image of the same dimensions as the mask
     */
    private Image getColorOverlay(int[] palette) {
        int height = mask.length;
        int width = mask[0].length;
        int[] pixels = new int[width * height];
        for (int y = 0; y < height; y++) {
            int[] row = mask[y];
            int offset = y * width;
            for (int x = 0; x < width; x++) {
                pixels[offset + x] = palette[row[x]];
            }
        }
        return ImageFactory.getInstance().fromPixels(pixels, width, height);
    }

    /**
     * Builds a palette with {@code background} for class {@code 0} and a random RGB color (with
     * the given opacity as alpha) for every other class.
     *
     * @param background packed color used verbatim for class {@code 0}
     * @param opacity alpha for the random colors; clamped to {@code [0, 255]}
     * @return one packed color per class
     */
    private int[] generateColors(int background, int opacity) {
        int size = classes.size();
        int[] palette = new int[size];
        palette[0] = background;
        int alpha = toAlpha(opacity) << 24;
        for (int k = 1; k < size; k++) {
            // R, G and B are drawn in that order.
            int red = RandomUtils.nextInt(256);
            int green = RandomUtils.nextInt(256);
            int blue = RandomUtils.nextInt(256);
            palette[k] = alpha | (red << 16) | (green << 8) | blue;
        }
        return palette;
    }

    /**
     * Clamps an opacity value to the 8-bit alpha range {@code [0, 255]}.
     *
     * <p>Without this, {@code opacity << 24} keeps only the low 8 bits of {@code opacity}, so e.g.
     * {@code 256} silently wraps to a fully transparent overlay.
     *
     * @param opacity the requested opacity
     * @return {@code opacity} limited to {@code [0, 255]}
     */
    private static int toAlpha(int opacity) {
        return Math.max(0, Math.min(MAX_ALPHA, opacity));
    }
}
