package org.nd4j.linalg.convolution;

import java.util.Arrays;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.fft.FFT;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.ComplexNDArrayUtil;
import org.nd4j.linalg.util.Shape;

/* loaded from: input_file:org/nd4j/linalg/convolution/DefaultConvolutionInstance.class */
/**
 * Default FFT-based implementation of {@link ConvolutionInstance}.
 *
 * <p>Both {@code convn} overloads forward-transform each operand at the full linear-convolution
 * size ({@code a[i] + b[i] - 1} per dimension), multiply the spectra element-wise and
 * inverse-transform the product. The result is then cropped according to the requested
 * {@link Convolution.Type}:
 *
 * <ul>
 *   <li>{@code FULL} (and any unrecognised type) - the whole result;
 *   <li>{@code SAME} - a centered crop with the shape of the first operand;
 *   <li>{@code VALID} - a centered crop of {@code |a[i] - b[i]| + 1} per dimension.
 * </ul>
 */
public class DefaultConvolutionInstance extends BaseConvolution {

    /**
     * Convolves two complex arrays using FFTs.
     *
     * @param iComplexNDArray the first operand; also defines the output shape in {@code SAME} mode
     * @param iComplexNDArray2 the second operand
     * @param type which part of the full convolution to return
     * @param iArr axes forwarded to {@link Shape#sizeForAxes} and to the FFT routines
     * @return the (possibly cropped) convolution, or the plain product if both operands are scalars
     * @throws IllegalArgumentException if the per-axis sizes or the shapes used for the
     *     {@code VALID} crop have different ranks
     */
    @Override // org.nd4j.linalg.convolution.ConvolutionInstance
    public IComplexNDArray convn(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, Convolution.Type type, int[] iArr) {
        if (iComplexNDArray2.isScalar() && iComplexNDArray.isScalar()) {
            return iComplexNDArray2.mul((INDArray) iComplexNDArray);
        }
        int[] fullShape = fullShape(
                Shape.sizeForAxes(iArr, iComplexNDArray.shape()),
                Shape.sizeForAxes(iArr, iComplexNDArray2.shape()));
        IComplexNDArray spectrum = FFT.rawfftn(iComplexNDArray, fullShape, iArr)
                .muli((INDArray) FFT.rawfftn(iComplexNDArray2, fullShape, iArr));
        IComplexNDArray full = FFT.rawifftn(spectrum, fullShape, iArr);
        switch (type) {
            case SAME:
                return ComplexNDArrayUtil.center(full, iComplexNDArray.shape());
            case VALID:
                return ComplexNDArrayUtil.center(
                        full, validShape(iComplexNDArray.shape(), iComplexNDArray2.shape()));
            case FULL:
            default:
                return full;
        }
    }

    /**
     * Convolves two real arrays using FFTs and returns the real part of the result.
     *
     * <p>If the operands have different ranks, the lower-rank one is reshaped by prepending
     * size-1 dimensions so that both ranks match.
     *
     * @param iNDArray the first operand; also defines the output shape in {@code SAME} mode
     * @param iNDArray2 the second operand
     * @param type which part of the full convolution to return
     * @param iArr axes forwarded to the FFT routine
     * @return the (possibly cropped) real convolution, or the plain product if both operands are
     *     scalars
     */
    @Override // org.nd4j.linalg.convolution.ConvolutionInstance
    public INDArray convn(INDArray iNDArray, INDArray iNDArray2, Convolution.Type type, int[] iArr) {
        INDArray first = iNDArray;
        INDArray second = iNDArray2;
        int firstRank = first.shape().length;
        int secondRank = second.shape().length;
        if (firstRank < secondRank) {
            first = first.reshape(prependOnes(first.shape(), secondRank));
        } else if (secondRank < firstRank) {
            second = second.reshape(prependOnes(second.shape(), firstRank));
        }
        if (second.isScalar() && first.isScalar()) {
            return second.mul(first);
        }
        int[] fullShape = fullShape(first.shape(), second.shape());
        IComplexNDArray firstSpectrum = FFT.rawfftn(Nd4j.createComplex(first), fullShape, iArr);
        IComplexNDArray secondSpectrum = FFT.rawfftn(Nd4j.createComplex(second), fullShape, iArr);
        // If the spectra differ in shape, zero-pad the one with fewer elements to the other's
        // shape before the element-wise product.
        if (!Arrays.equals(firstSpectrum.shape(), secondSpectrum.shape())) {
            if (firstSpectrum.length() < secondSpectrum.length()) {
                firstSpectrum = ComplexNDArrayUtil.padWithZeros(firstSpectrum, secondSpectrum.shape());
            } else {
                secondSpectrum = ComplexNDArrayUtil.padWithZeros(secondSpectrum, firstSpectrum.shape());
            }
        }
        IComplexNDArray full = FFT.ifftn(firstSpectrum.muli((INDArray) secondSpectrum));
        switch (type) {
            case SAME:
                return ComplexNDArrayUtil.center(full, first.shape()).getReal();
            case VALID:
                return ComplexNDArrayUtil.center(full, validShape(first.shape(), second.shape())).getReal();
            case FULL:
            default:
                return full.getReal();
        }
    }

    /**
     * Returns {@code shape} left-padded with 1s so that it has exactly {@code rank} dimensions.
     * Callers guarantee {@code rank >= shape.length}.
     */
    private static int[] prependOnes(int[] shape, int rank) {
        int[] padded = new int[rank];
        int offset = rank - shape.length;
        Arrays.fill(padded, 0, offset, 1);
        System.arraycopy(shape, 0, padded, offset, shape.length);
        return padded;
    }

    /** Returns the per-dimension size of a full linear convolution: {@code a[i] + b[i] - 1}. */
    private static int[] fullShape(int[] a, int[] b) {
        requireSameRank(a, b);
        int[] out = new int[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] + b[i] - 1;
        }
        return out;
    }

    /**
     * Returns the per-dimension size of the {@code VALID} region: {@code |a[i] - b[i]| + 1}, the
     * number of positions at which the smaller operand fits entirely inside the larger one.
     */
    private static int[] validShape(int[] a, int[] b) {
        requireSameRank(a, b);
        int[] out = new int[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = Math.abs(a[i] - b[i]) + 1;
        }
        return out;
    }

    /** Throws {@link IllegalArgumentException} if the two shapes have different ranks. */
    private static void requireSameRank(int[] a, int[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                    "Shapes must have the same rank: " + Arrays.toString(a) + " vs " + Arrays.toString(b));
        }
    }
}
