package org.nd4j.linalg.convolution.test;

import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/convolution/test/ConvolutionTests.class */
/**
 * Tests for {@link Convolution#convn} in {@link Convolution.Type#VALID} mode.
 *
 * <p>The class is {@code abstract}, so JUnit will not run it directly. Presumably a
 * backend-specific subclass extends it to run these tests. TODO confirm.
 */
public abstract class ConvolutionTests {

    /**
     * Convolves {@code linspace(1, 8, 8)} with {@code linspace(1, 3, 3)} in VALID mode and compares
     * the result against a single-element expected array.
     *
     * <p>{@code Nd4j.EPS_THRESHOLD} is a global static. It is loosened to {@code 0.1} only for the
     * duration of this comparison. NOTE(review): this presumably relaxes the tolerance used by
     * NDArray equality; confirm. The previous value is restored in {@code finally} so that the
     * change cannot leak into other tests. The outcome of those tests would otherwise depend on
     * execution order.
     */
    @Test
    public void convNTest() {
        final double previousThreshold = Nd4j.EPS_THRESHOLD;
        Nd4j.EPS_THRESHOLD = 0.1d;
        try {
            Assert.assertEquals(
                    Nd4j.create(new double[]{1.0000012d}),
                    Convolution.convn(
                            Nd4j.linspace(1, 8, 8),
                            Nd4j.linspace(1, 3, 3),
                            Convolution.Type.VALID));
        } finally {
            // Restore global state even if the assertion fails.
            Nd4j.EPS_THRESHOLD = previousThreshold;
        }
    }

    /**
     * Convolves a 2x4 input (values 1..8) with a 2x3 kernel (values 1..6) in VALID mode and
     * checks the result against {@code [56, 98]}.
     */
    @Test
    public void testConv2d() {
        Assert.assertEquals(
                Nd4j.create(new double[]{56.0d, 98.0d}),
                Convolution.convn(
                        Nd4j.linspace(1, 8, 8).reshape(2, 4),
                        Nd4j.linspace(1, 6, 6).reshape(2, 3),
                        Convolution.Type.VALID));
    }
}
