package org.nd4j.linalg.fft;

import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.VectorFFT;
import org.nd4j.linalg.api.ops.impl.transforms.VectorIFFT;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.ComplexNDArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/fft/DefaultFFTInstance.class */
public class DefaultFFTInstance extends BaseFFTInstance {

    /**
     * Computes the FFT of a real-valued array along one dimension.
     *
     * <p>The input is converted once with {@link Nd4j#createComplex(INDArray)}.
     * - Vectors are transformed directly with {@link VectorFFT}; {@code dimension} is not used in that case.
     * - Other arrays are delegated to {@link #rawfft(IComplexNDArray, int, int)}. That method zero-pads or
     *   truncates along {@code dimension} to {@code n} elements before transforming.
     *
     * @param transform the real-valued array to transform
     * @param n         the FFT length
     * @param dimension the dimension to transform along (ignored for vectors)
     * @return the complex-valued transform
     */
    @Override
    public IComplexNDArray fft(INDArray transform, int n, int dimension) {
        IComplexNDArray complex = Nd4j.createComplex(transform);
        if (complex.isVector()) {
            return vectorFft(complex, n);
        }
        // rawfft copies the array and resizes it to n along `dimension` itself, so no
        // separate conversion/dup/pad/truncate is needed here.
        return rawfft(complex, n, dimension);
    }

    /**
     * Computes the FFT of a complex array along one dimension.
     *
     * <p>Vectors go straight to {@link VectorFFT}; {@code dimension} is not used in that case.
     * NOTE(review): the vector path hands the caller's array directly to the op. Confirm whether
     * VectorFFT mutates it in place.
     * Other arrays are delegated to {@link #rawfft(IComplexNDArray, int, int)}.
     *
     * @param transform the complex array to transform
     * @param n         the FFT length
     * @param dimension the dimension to transform along (ignored for vectors)
     * @return the transformed array
     */
    @Override
    public IComplexNDArray fft(IComplexNDArray transform, int n, int dimension) {
        if (transform.isVector()) {
            return vectorFft(transform, n);
        }
        return rawfft(transform, n, dimension);
    }

    /**
     * Computes the inverse FFT of a real-valued array along one dimension.
     *
     * <p>The input is converted with {@link Nd4j#createComplex(INDArray)}.
     * - Vectors are transformed directly with {@link VectorIFFT}.
     * - Other arrays are delegated to {@link #rawifft(IComplexNDArray, int, int)}.
     *
     * @param transform the real-valued array to transform
     * @param n         the inverse-FFT length
     * @param dimension the dimension to transform along (ignored for vectors)
     * @return the complex-valued inverse transform
     */
    @Override
    public IComplexNDArray ifft(INDArray transform, int n, int dimension) {
        IComplexNDArray complex = Nd4j.createComplex(transform);
        if (complex.isVector()) {
            return vectorIfft(complex, n);
        }
        return rawifft(complex, n, dimension);
    }

    /**
     * Computes the inverse FFT of a complex array along one dimension.
     *
     * <p>Vectors go straight to {@link VectorIFFT}. Other arrays are delegated to
     * {@link #rawifft(IComplexNDArray, int, int)}.
     *
     * @param inputC    the complex array to transform
     * @param n         the inverse-FFT length
     * @param dimension the dimension to transform along (ignored for vectors)
     * @return the transformed array
     */
    @Override
    public IComplexNDArray ifft(IComplexNDArray inputC, int n, int dimension) {
        if (inputC.isVector()) {
            return vectorIfft(inputC, n);
        }
        return rawifft(inputC, n, dimension);
    }

    /**
     * Computes the inverse FFT of a real-valued array along its last dimension.
     *
     * @param transform the real-valued array to transform
     * @param n         the inverse-FFT length
     * @return the complex-valued inverse transform
     */
    @Override
    public IComplexNDArray ifft(INDArray transform, int n) {
        IComplexNDArray complex = Nd4j.createComplex(transform);
        if (complex.isVector()) {
            return vectorIfft(complex, n);
        }
        return rawifft(complex, n, complex.shape().length - 1);
    }

    /**
     * Computes the inverse FFT of a complex array along its last dimension.
     *
     * <p>The transform length is the array's own size along that dimension, or
     * {@link IComplexNDArray#length()} for vectors. No padding or truncation occurs.
     *
     * @param inputC the complex array to transform
     * @return the transformed array
     */
    @Override
    public IComplexNDArray ifft(IComplexNDArray inputC) {
        if (inputC.isVector()) {
            return vectorIfft(inputC, inputC.length());
        }
        int last = inputC.shape().length - 1;
        return rawifft(inputC, inputC.size(last), last);
    }

    /**
     * Applies a row-wise forward FFT along {@code dimension} to a copy of the input.
     *
     * <p>The copy is first truncated or zero-padded so that its size along {@code dimension} is
     * exactly {@code n}. The input array itself is not modified by the resize step.
     *
     * @param transform the complex array to transform
     * @param n         the FFT length
     * @param dimension the dimension to transform along
     * @return the transformed copy, with the original axis order
     */
    @Override
    public IComplexNDArray rawfft(IComplexNDArray transform, int n, int dimension) {
        return transformAlong(transform, n, dimension, false);
    }

    /**
     * Applies a row-wise inverse FFT along {@code dimension} to a copy of the input.
     *
     * <p>The copy is first truncated or zero-padded so that its size along {@code dimension} is
     * exactly {@code n}. The input array itself is not modified by the resize step.
     *
     * @param transform the complex array to transform
     * @param n         the inverse-FFT length
     * @param dimension the dimension to transform along
     * @return the transformed copy, with the original axis order
     */
    @Override
    public IComplexNDArray rawifft(IComplexNDArray transform, int n, int dimension) {
        return transformAlong(transform, n, dimension, true);
    }

    /** Runs {@link VectorFFT} of length {@code n} on a vector and returns the op's result. */
    private static IComplexNDArray vectorFft(IComplexNDArray vector, int n) {
        return (IComplexNDArray) Nd4j.getExecutioner().execAndReturn(new VectorFFT(vector, n));
    }

    /** Runs {@link VectorIFFT} of length {@code n} on a vector and returns the op's result. */
    private static IComplexNDArray vectorIfft(IComplexNDArray vector, int n) {
        return (IComplexNDArray) Nd4j.getExecutioner().execAndReturn(new VectorIFFT(vector, n));
    }

    /**
     * Shared body of {@link #rawfft} and {@link #rawifft}. Resizes a copy of the input, then moves
     * {@code dimension} to the last axis. It runs the row-wise (inverse) FFT op and finally
     * restores the axis order.
     */
    private static IComplexNDArray transformAlong(
            IComplexNDArray input, int n, int dimension, boolean inverse) {
        IComplexNDArray result = resizedCopy(input, n, dimension);
        // swapAxes never changes the rank, so the last-axis index stays valid after swapping.
        int last = result.shape().length - 1;
        boolean moved = dimension != last;
        if (moved) {
            // The row-wise op is applied after the target dimension has been moved to the last axis.
            result = result.swapAxes(last, dimension);
        }
        if (inverse) {
            Nd4j.getExecutioner().iterateOverAllRows(new VectorIFFT(result, n));
        } else {
            Nd4j.getExecutioner().iterateOverAllRows(new VectorFFT(result, n));
        }
        if (moved) {
            // Swapping the same pair of axes again restores the original order.
            result = result.swapAxes(last, dimension);
        }
        return result;
    }

    /**
     * Returns a copy of {@code input} whose size along {@code dimension} is exactly {@code n}.
     * The copy is truncated when the input is longer and zero-padded when it is shorter.
     */
    private static IComplexNDArray resizedCopy(IComplexNDArray input, int n, int dimension) {
        IComplexNDArray copy = input.dup();
        int current = input.size(dimension);
        if (current > n) {
            return ComplexNDArrayUtil.truncate(copy, n, dimension);
        }
        if (current < n) {
            int[] targetShape = ArrayUtil.copy(copy.shape());
            targetShape[dimension] = n;
            return ComplexNDArrayUtil.padWithZeros(copy, targetShape);
        }
        return copy;
    }
}
