package org.nd4j.linalg.api.ops;

import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.FloatBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.exception.IllegalOpException;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Max;
import org.nd4j.linalg.api.ops.impl.accum.Mean;
import org.nd4j.linalg.api.ops.impl.accum.Min;
import org.nd4j.linalg.api.ops.impl.accum.Norm2;
import org.nd4j.linalg.api.ops.impl.accum.NormMax;
import org.nd4j.linalg.api.ops.impl.accum.Prod;
import org.nd4j.linalg.api.ops.impl.accum.Sum;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
import org.nd4j.linalg.api.ops.impl.transforms.Exp;
import org.nd4j.linalg.api.ops.impl.transforms.Log;
import org.nd4j.linalg.api.ops.impl.transforms.Pow;
import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/api/ops/OpExecutionerTests.class */
/**
 * Backend-independent checks for the op executioner returned by
 * {@link Nd4j#getExecutioner()}.
 *
 * <p>The class is abstract; the tests only talk to ND4J through {@link Nd4j} and
 * {@link OpExecutioner}. Each test builds a small input, runs one or more ops and
 * compares the outcome against a hand-computed value.
 */
public abstract class OpExecutionerTests {
    /** Absolute tolerance used for every floating point comparison in this class. */
    private static final double DELTA = 0.1d;

    /**
     * Resets the factory ordering to {@code 'f'} after every test, because
     * {@link #testDimensionSoftMax()} switches it to {@code 'c'}.
     */
    @After
    public void after() {
        Nd4j.factory().setOrder('f');
    }

    /** Cosine similarity of a vector with an identical copy must be 1. */
    @Test
    public void testCosineSimilarity() {
        INDArray first = Nd4j.create(new float[]{1.0f, 2.0f, 3.0f, 4.0f});
        INDArray second = Nd4j.create(new float[]{1.0f, 2.0f, 3.0f, 4.0f});
        Assert.assertEquals(1.0d, Transforms.cosineSim(first, second), DELTA);
    }

    /** Distance between (55, 55) and (60, 60) is sqrt(5^2 + 5^2) = sqrt(50). */
    @Test
    public void testEuclideanDistance() {
        INDArray from = Nd4j.create(new double[]{55.0d, 55.0d});
        INDArray to = Nd4j.create(new double[]{60.0d, 60.0d});
        double distance = Nd4j.getExecutioner()
                .execAndReturn(new EuclideanDistance(from, to))
                .currentResult()
                .doubleValue();
        Assert.assertEquals(7.0710678118654755d, distance, DELTA);
    }

    /** Applying {@code ScalarMax} with 1 to the values -1..-6 is expected to leave all ones. */
    @Test
    public void testScalarMaxOp() {
        INDArray negated = Nd4j.linspace(1, 6, 6).negi();
        INDArray expected = Nd4j.ones(6);
        Nd4j.getExecutioner().exec(new ScalarMax(negated, 1));
        // Expected value first, matching the JUnit contract and the other tests in this class.
        Assert.assertEquals(expected, negated);
    }

    /** After {@code SetRange(x, min, max)} every element of {@code x} must lie in [min, max]. */
    @Test
    public void testSetRange() {
        INDArray unitRange = Nd4j.linspace(1, 4, 4);
        Nd4j.getExecutioner().exec(new SetRange(unitRange, 0.0d, 1.0d));
        assertAllWithin(unitRange, 0.0d, 1.0d);

        INDArray shiftedRange = Nd4j.linspace(1, 4, 4);
        Nd4j.getExecutioner().exec(new SetRange(shiftedRange, 2.0d, 4.0d));
        assertAllWithin(shiftedRange, 2.0d, 4.0d);
    }

    /**
     * Checks the accumulated result of {@code NormMax} over {1, 2, 3, 4}.
     *
     * <p>NOTE(review): a max norm of this vector would conventionally be 4, while the
     * expected value here is 10 (the sum of the elements). The expectation is kept
     * unchanged; confirm it against the {@code NormMax} implementation.
     */
    @Test
    public void testNormMax() {
        INDArray values = Nd4j.create(new float[]{1.0f, 2.0f, 3.0f, 4.0f});
        double normMax = Nd4j.getExecutioner()
                .execAndReturn(new NormMax(values))
                .currentResult()
                .doubleValue();
        Assert.assertEquals(10.0d, normMax, DELTA);
    }

    /** Euclidean norm of {1, 2, 3, 4} is sqrt(30). */
    @Test
    public void testNorm2() {
        INDArray values = Nd4j.create(new float[]{1.0f, 2.0f, 3.0f, 4.0f});
        double norm2 = Nd4j.getExecutioner()
                .execAndReturn(new Norm2(values))
                .currentResult()
                .doubleValue();
        Assert.assertEquals(5.477225575051661d, norm2, DELTA);
    }

    /** Adds a vector of ones to itself, writing into the first operand; expects all twos. */
    @Test
    public void testAdd() {
        OpExecutioner executioner = Nd4j.getExecutioner();
        INDArray x = Nd4j.ones(5);
        INDArray y = x.dup();
        INDArray expected = Nd4j.valueArrayOf(5, 2.0d);
        executioner.exec(new AddOp(x, y, x));
        Assert.assertEquals(expected, x);
    }

    /** Multiplies a vector of ones by itself, writing into the first operand; expects all ones. */
    @Test
    public void testMul() {
        OpExecutioner executioner = Nd4j.getExecutioner();
        INDArray x = Nd4j.ones(5);
        INDArray y = x.dup();
        INDArray expected = Nd4j.valueArrayOf(5, 1.0d);
        executioner.exec(new MulOp(x, y, x));
        Assert.assertEquals(expected, x);
    }

    /**
     * Runs an element-wise add followed by two accumulations over the result.
     *
     * @throws IllegalOpException declared by the original test signature
     */
    @Test
    public void testExecutioner() throws IllegalOpException {
        OpExecutioner executioner = Nd4j.getExecutioner();
        INDArray x = Nd4j.ones(5);
        INDArray y = x.dup();
        INDArray expected = Nd4j.valueArrayOf(5, 2.0d);
        executioner.exec(new AddOp(x, y, x));
        Assert.assertEquals(expected, x);

        // x is now five twos: sum = 5 * 2 = 10, product = 2^5 = 32.
        Sum sum = new Sum(x.dup());
        executioner.exec(sum);
        Assert.assertEquals(10.0d, sum.currentResult().doubleValue(), DELTA);

        Prod prod = new Prod(x.dup());
        executioner.exec(prod);
        Assert.assertEquals(32.0d, prod.currentResult().doubleValue(), DELTA);
    }

    /** Max and min accumulations over the values 1..5. */
    @Test
    public void testMaxMin() {
        OpExecutioner executioner = Nd4j.getExecutioner();
        INDArray values = Nd4j.linspace(1, 5, 5);

        Max max = new Max(values);
        executioner.exec(max);
        Assert.assertEquals(5.0d, max.currentResult().doubleValue(), DELTA);

        // The op has to be executed before its result is read; previously the Min op
        // was constructed but never handed to the executioner.
        Min min = new Min(values);
        executioner.exec(min);
        Assert.assertEquals(1.0d, min.currentResult().doubleValue(), DELTA);
    }

    /** Product of 1..6 is 6! = 720. */
    @Test
    public void testProd() {
        INDArray values = Nd4j.linspace(1, 6, 6);
        double product = Nd4j.getExecutioner()
                .execAndReturn(new Prod(values))
                .currentResult()
                .doubleValue();
        Assert.assertEquals(720.0d, product, DELTA);
    }

    /** Sum of 1..6 is 21. */
    @Test
    public void testSum() {
        INDArray values = Nd4j.linspace(1, 6, 6);
        double total = Nd4j.getExecutioner()
                .execAndReturn(new Sum(values))
                .currentResult()
                .doubleValue();
        Assert.assertEquals(21.0d, total, DELTA);
    }

    /**
     * Same checks as {@link #testDescriptiveStats()} but with the global data type
     * switched to double for the duration of the test.
     */
    @Test
    public void testDescriptiveStatsDouble() {
        // Nd4j.dtype is global state; remember it so other tests are not affected.
        DataBuffer.Type previousType = Nd4j.dtype;
        Nd4j.dtype = DataBuffer.Type.DOUBLE;
        try {
            OpExecutioner executioner = Nd4j.getExecutioner();
            INDArray values = Nd4j.linspace(1, 5, 5);

            Mean mean = new Mean(values);
            executioner.exec(mean);
            Assert.assertEquals(3.0d, mean.currentResult().doubleValue(), DELTA);

            Variance variance = new Variance(values.dup(), true);
            executioner.exec(variance);
            Assert.assertEquals(2.5d, variance.currentResult().doubleValue(), DELTA);
        } finally {
            Nd4j.dtype = previousType;
        }
    }

    /** Mean (3) and variance (2.5, constructed with the {@code true} flag) of 1..5. */
    @Test
    public void testDescriptiveStats() {
        OpExecutioner executioner = Nd4j.getExecutioner();
        INDArray values = Nd4j.linspace(1, 5, 5);

        Mean mean = new Mean(values);
        executioner.exec(mean);
        Assert.assertEquals(3.0d, mean.currentResult().doubleValue(), DELTA);

        Variance variance = new Variance(values.dup(), true);
        executioner.exec(variance);
        Assert.assertEquals(2.5d, variance.currentResult().doubleValue(), DELTA);
    }

    /** Softmax output over a single row of 1..6 must sum to 1. */
    @Test
    public void testRowSoftmax() {
        OpExecutioner executioner = Nd4j.getExecutioner();
        SoftMax softMax = new SoftMax(Nd4j.linspace(1, 6, 6));
        executioner.exec(softMax);
        double total = softMax.z().sum(Integer.MAX_VALUE).getDouble(0);
        Assert.assertEquals(1.0d, total, DELTA);
    }

    /** Squares 1..6 and compares with the explicit squares. */
    @Test
    public void testPow() {
        Pow pow = new Pow(Nd4j.linspace(1, 6, 6), 2.0d);
        Nd4j.getExecutioner().exec(pow);
        INDArray expected = Nd4j.create(new float[]{1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f});
        Assert.assertEquals(expected, pow.z());
    }

    /**
     * Scalar comparisons of 1..6 against 0 and 7: every element is greater than 0 and
     * less than 7, so the results are expected to be all ones or all zeros.
     */
    @Test
    public void testComparisonOps() {
        INDArray values = Nd4j.linspace(1, 6, 6);
        INDArray allOnes = Nd4j.ones(6);
        INDArray allZeros = Nd4j.zeros(6);
        Assert.assertEquals(allOnes, Nd4j.getExecutioner().execAndReturn(new ScalarGreaterThan(values, 0)));
        Assert.assertEquals(allZeros, Nd4j.getExecutioner().execAndReturn(new ScalarGreaterThan(values, 7)));
        Assert.assertEquals(allZeros, Nd4j.getExecutioner().execAndReturn(new ScalarLessThan(values, 0)));
        Assert.assertEquals(allOnes, Nd4j.getExecutioner().execAndReturn(new ScalarLessThan(values, 7)));
    }

    /** Adding the scalar 1 to 1..6 is expected to yield 2..7. */
    @Test
    public void testScalarArithmetic() {
        INDArray values = Nd4j.linspace(1, 6, 6);
        INDArray expected = Nd4j.linspace(2, 7, 6);
        Nd4j.getExecutioner().exec(new ScalarAdd(values, 1));
        Assert.assertEquals(expected, values);
    }

    /**
     * Max and min over the first slice of 1..6 reshaped to 2x3.
     *
     * <p>The expected values (5 and 1) correspond to the slice holding {1, 3, 5}.
     */
    @Test
    public void testDimensionMax() {
        INDArray firstSlice = Nd4j.linspace(1, 6, 6).reshape(2, 3).slice(0);
        double max = Nd4j.getExecutioner().execAndReturn(new Max(firstSlice)).currentResult().doubleValue();
        Assert.assertEquals(5.0d, max, DELTA);
        double min = Nd4j.getExecutioner().execAndReturn(new Min(firstSlice)).currentResult().doubleValue();
        Assert.assertEquals(1.0d, min, DELTA);
    }

    /** Natural log applied to a slice view; expected values are ln(1), ln(3), ln(5). */
    @Test
    public void testStridedLog() {
        OpExecutioner executioner = Nd4j.getExecutioner();
        INDArray firstSlice = Nd4j.linspace(1, 6, 6).reshape(2, 3).slice(0);
        executioner.exec(new Log(firstSlice));
        INDArray expected = Nd4j.create(new FloatBuffer(new float[]{0.0f, 1.0986123f, 1.609438f}));
        Assert.assertEquals(expected, firstSlice);
    }

    /** Exponential applied to a slice view; expected values are e^1, e^3, e^5. */
    @Test
    public void testStridedExp() {
        OpExecutioner executioner = Nd4j.getExecutioner();
        INDArray firstSlice = Nd4j.linspace(1, 6, 6).reshape(2, 3).slice(0);
        executioner.exec(new Exp(firstSlice));
        INDArray expected = Nd4j.create(new FloatBuffer(new float[]{2.7182817f, 20.085537f, 148.41316f}));
        Assert.assertEquals(expected, firstSlice);
    }

    /** Softmax output over 1..6 must sum to 1 (same scenario as {@link #testRowSoftmax()}). */
    @Test
    public void testSoftMax() {
        OpExecutioner executioner = Nd4j.getExecutioner();
        SoftMax softMax = new SoftMax(Nd4j.linspace(1, 6, 6));
        executioner.exec(softMax);
        double total = softMax.z().sum(Integer.MAX_VALUE).getDouble(0);
        Assert.assertEquals(1.0d, total, DELTA);
    }

    /**
     * Softmax executed with dimension argument 1 on a 2x3 matrix created in
     * {@code 'c'} order; the first row must then sum to 1.
     */
    @Test
    public void testDimensionSoftMax() {
        // Order is restored to 'f' by after().
        Nd4j.factory().setOrder('c');
        INDArray matrix = Nd4j.linspace(1, 6, 6).reshape(2, 3);
        Nd4j.getExecutioner().exec(new SoftMax(matrix), 1);
        double firstRowTotal = matrix.getRow(0).sum(Integer.MAX_VALUE).getDouble(0);
        // Expected value first so that a failure message reads correctly.
        Assert.assertEquals(1.0d, firstRowTotal, DELTA);
    }

    /**
     * Asserts that every element of {@code values} lies in the closed interval
     * {@code [lower, upper]}.
     *
     * @param values array whose elements are checked by linear index
     * @param lower  inclusive lower bound
     * @param upper  inclusive upper bound
     */
    private static void assertAllWithin(INDArray values, double lower, double upper) {
        for (int i = 0; i < values.length(); i++) {
            double value = values.getDouble(i);
            Assert.assertTrue(
                    "element " + i + " = " + value + " is outside [" + lower + ", " + upper + "]",
                    value >= lower && value <= upper);
        }
    }
}
