package org.jetbrains.kotlinx.dl.visualization.letsplot;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import jetbrains.letsPlot.Figure;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.kotlinx.dl.api.core.TrainableModel;
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D;

/* compiled from: PlotConv2D.kt */
@Metadata(mv = {1, 5, 1}, k = 2, xi = 48, d1 = {"��6\n��\n\u0002\u0010\u0015\n\u0002\b\u0005\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0010 \n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0014\n��\u001a,\u0010\u0006\u001a\u00020\u00072\u0006\u0010\b\u001a\u00020\t2\b\b\u0002\u0010\n\u001a\u00020\u000b2\b\b\u0002\u0010\f\u001a\u00020\r2\b\b\u0002\u0010\u000e\u001a\u00020\r\u001a:\u0010\u000f\u001a\b\u0012\u0004\u0012\u00020\u00070\u00102\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0013\u001a\u00020\u00142\b\b\u0002\u0010\n\u001a\u00020\u000b2\b\b\u0002\u0010\f\u001a\u00020\r2\b\b\u0002\u0010\u000e\u001a\u00020\r\"\u0014\u0010��\u001a\u00020\u0001X\u0080\u0004¢\u0006\b\n��\u001a\u0004\b\u0002\u0010\u0003\"\u0014\u0010\u0004\u001a\u00020\u0001X\u0080\u0004¢\u0006\b\n��\u001a\u0004\b\u0005\u0010\u0003¨\u0006\u0015"}, d2 = {"ACTIVATION_LAYERS_PERMUTATION", "", "getACTIVATION_LAYERS_PERMUTATION", "()[I", "FILTER_LAYERS_PERMUTATION", "getFILTER_LAYERS_PERMUTATION", "filtersPlot", "Ljetbrains/letsPlot/Figure;", "conv2DLayer", "Lorg/jetbrains/kotlinx/dl/api/core/layer/convolutional/Conv2D;", "plotFeature", "Lorg/jetbrains/kotlinx/dl/visualization/letsplot/PlotFeature;", "imageSize", "", "columns", "modelActivationOnLayersPlot", "", "model", "Lorg/jetbrains/kotlinx/dl/api/core/TrainableModel;", "x", "", "visualization"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/visualization/letsplot/PlotConv2DKt.class */
public final class PlotConv2DKt {

    @NotNull
    private static final int[] FILTER_LAYERS_PERMUTATION = {1, 0, 2, 3};

    @NotNull
    private static final int[] ACTIVATION_LAYERS_PERMUTATION = {2, 1, 0, 3};

    @NotNull
    public static final int[] getFILTER_LAYERS_PERMUTATION() {
        return FILTER_LAYERS_PERMUTATION;
    }

    @NotNull
    public static final int[] getACTIVATION_LAYERS_PERMUTATION() {
        return ACTIVATION_LAYERS_PERMUTATION;
    }

    @NotNull
    public static final Figure filtersPlot(@NotNull Conv2D conv2D, @NotNull PlotFeature plotFeature, int i, int i2) {
        Intrinsics.checkNotNullParameter(conv2D, "conv2DLayer");
        Intrinsics.checkNotNullParameter(plotFeature, "plotFeature");
        Object[] array = conv2D.getWeights().values().toArray(new Object[0]);
        if (array == null) {
            throw new NullPointerException("null cannot be cast to non-null type kotlin.Array<T>");
        }
        final float[][][][] fArr = (float[][][][]) ((Object[][]) array)[0];
        int[] extractXYInputOutputAxeSizes = PlotHelpersKt.extractXYInputOutputAxeSizes(fArr, FILTER_LAYERS_PERMUTATION);
        List<Pair<Integer, Integer>> cartesianProductIndices = PlotHelpersKt.cartesianProductIndices(extractXYInputOutputAxeSizes[2], extractXYInputOutputAxeSizes[3]);
        ArrayList arrayList = new ArrayList(CollectionsKt.collectionSizeOrDefault(cartesianProductIndices, 10));
        Iterator<T> it = cartesianProductIndices.iterator();
        while (it.hasNext()) {
            Pair pair = (Pair) it.next();
            final int intValue = ((Number) pair.component1()).intValue();
            final int intValue2 = ((Number) pair.component2()).intValue();
            arrayList.add(PlotGenericKt.xyPlot(extractXYInputOutputAxeSizes[0], extractXYInputOutputAxeSizes[1], plotFeature, new Function2<Integer, Integer, Float>() { // from class: org.jetbrains.kotlinx.dl.visualization.letsplot.PlotConv2DKt$filtersPlot$plots$1$1
                /* JADX INFO: Access modifiers changed from: package-private */
                /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
                {
                    super(2);
                }

                @NotNull
                public final Float invoke(int i3, int i4) {
                    return Float.valueOf(fArr[i4][i3][intValue][intValue2]);
                }

                public /* bridge */ /* synthetic */ Object invoke(Object obj, Object obj2) {
                    return invoke(((Number) obj).intValue(), ((Number) obj2).intValue());
                }
            }));
        }
        return PlotGenericKt.columnPlot(arrayList, i2, i);
    }

    public static /* synthetic */ Figure filtersPlot$default(Conv2D conv2D, PlotFeature plotFeature, int i, int i2, int i3, Object obj) {
        if ((i3 & 2) != 0) {
            plotFeature = PlotFeature.Companion.getGRAY();
        }
        if ((i3 & 4) != 0) {
            i = 64;
        }
        if ((i3 & 8) != 0) {
            i2 = 8;
        }
        return filtersPlot(conv2D, plotFeature, i, i2);
    }

    @NotNull
    public static final List<Figure> modelActivationOnLayersPlot(@NotNull TrainableModel trainableModel, @NotNull float[] fArr, @NotNull PlotFeature plotFeature, int i, int i2) {
        Intrinsics.checkNotNullParameter(trainableModel, "model");
        Intrinsics.checkNotNullParameter(fArr, "x");
        Intrinsics.checkNotNullParameter(plotFeature, "plotFeature");
        List list = (List) TrainableModel.predictAndGetActivations$default(trainableModel, fArr, (String) null, 2, (Object) null).getSecond();
        ArrayList arrayList = new ArrayList();
        for (Object obj : list) {
            float[][][][] fArr2 = obj instanceof float[][][][] ? (float[][][][]) obj : null;
            if (fArr2 != null) {
                arrayList.add(fArr2);
            }
        }
        ArrayList<float[][][][]> arrayList2 = arrayList;
        ArrayList arrayList3 = new ArrayList(CollectionsKt.collectionSizeOrDefault(arrayList2, 10));
        for (final float[][][][] fArr3 : arrayList2) {
            int[] extractXYInputOutputAxeSizes = PlotHelpersKt.extractXYInputOutputAxeSizes(fArr3, getACTIVATION_LAYERS_PERMUTATION());
            List<Pair<Integer, Integer>> cartesianProductIndices = PlotHelpersKt.cartesianProductIndices(extractXYInputOutputAxeSizes[2], extractXYInputOutputAxeSizes[3]);
            ArrayList arrayList4 = new ArrayList(CollectionsKt.collectionSizeOrDefault(cartesianProductIndices, 10));
            Iterator<T> it = cartesianProductIndices.iterator();
            while (it.hasNext()) {
                Pair pair = (Pair) it.next();
                final int intValue = ((Number) pair.component1()).intValue();
                final int intValue2 = ((Number) pair.component2()).intValue();
                arrayList4.add(PlotGenericKt.xyPlot(extractXYInputOutputAxeSizes[0], extractXYInputOutputAxeSizes[1], plotFeature, new Function2<Integer, Integer, Float>() { // from class: org.jetbrains.kotlinx.dl.visualization.letsplot.PlotConv2DKt$modelActivationOnLayersPlot$1$plots$1$1
                    /* JADX INFO: Access modifiers changed from: package-private */
                    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
                    {
                        super(2);
                    }

                    @NotNull
                    public final Float invoke(int i3, int i4) {
                        return Float.valueOf(fArr3[intValue][i4][i3][intValue2]);
                    }

                    public /* bridge */ /* synthetic */ Object invoke(Object obj2, Object obj3) {
                        return invoke(((Number) obj2).intValue(), ((Number) obj3).intValue());
                    }
                }));
            }
            arrayList3.add(PlotGenericKt.columnPlot(arrayList4, i2, i));
        }
        return arrayList3;
    }

    public static /* synthetic */ List modelActivationOnLayersPlot$default(TrainableModel trainableModel, float[] fArr, PlotFeature plotFeature, int i, int i2, int i3, Object obj) {
        if ((i3 & 4) != 0) {
            plotFeature = PlotFeature.Companion.getGRAY();
        }
        if ((i3 & 8) != 0) {
            i = 64;
        }
        if ((i3 & 16) != 0) {
            i2 = 8;
        }
        return modelActivationOnLayersPlot(trainableModel, fArr, plotFeature, i, i2);
    }
}
