package org.nd4j.linalg.api.blas.impl;

import org.nd4j.linalg.api.blas.Level3;
import org.nd4j.linalg.api.blas.params.GemmParams;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/blas/impl/BaseLevel3.class */
/**
 * Base implementation of the {@link Level3} (matrix-matrix) BLAS routines.
 *
 * <p>Each public method optionally reports the call to the {@link OpProfiler}, selects a
 * precision-specific kernel from the data type of the first array operand, validates that
 * the operands share that data type, delegates to the matching abstract kernel, and finally
 * runs {@link OpExecutionerUtil#checkForAny(INDArray)} on the result array.
 *
 * <p>Concrete backends supply the {@code s*}, {@code d*} and {@code h*} kernels declared at
 * the bottom of this class.
 */
public abstract class BaseLevel3 extends BaseLevel implements Level3 {
    private static final Logger log = LoggerFactory.getLogger(BaseLevel3.class);

    /**
     * Tells whether the current executioner is running with {@code ProfilingMode.ALL}, which
     * is the only mode in which this class reports BLAS calls to the {@link OpProfiler}.
     */
    private static boolean profilingAllOps() {
        return Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    }

    /**
     * General matrix-matrix multiply; in BLAS terms {@code C := alpha * op(A) * op(B) + beta * C}.
     *
     * <p>The dimensions, leading dimensions, transpose flags and the A/B operands handed to the
     * kernel are all taken from a {@link GemmParams} built from the three arrays. The
     * {@code c2} and {@code c3} arguments are therefore not read by this method.
     *
     * <p>The scalars {@code d} and {@code d2} are forwarded to the kernel (narrowed to
     * {@code float} for the single- and half-precision kernels). Earlier revisions passed the
     * constants 1 and 0 instead, silently discarding the caller's values; callers that pass
     * 1.0 and 0.0 observe no difference.
     *
     * @param c         passed through to the kernel as its first (ordering) argument
     * @param c2        transpose flag for A; unused, see above
     * @param c3        transpose flag for B; unused, see above
     * @param d         scalar applied to the matrix product (alpha)
     * @param iNDArray  first operand (A); its data type selects the kernel
     * @param iNDArray2 second operand (B)
     * @param d2        scalar applied to the existing contents of the result (beta)
     * @param iNDArray3 result array (C)
     */
    @Override
    public void gemm(char c, char c2, char c3, double d, INDArray iNDArray, INDArray iNDArray2, double d2, INDArray iNDArray3) {
        if (profilingAllOps()) {
            OpProfiler.getInstance().processBlasCall(true, iNDArray, iNDArray2, iNDArray3);
        }
        GemmParams gemmParams = new GemmParams(iNDArray, iNDArray2, iNDArray3);
        DataBuffer.Type type = iNDArray.data().dataType();
        if (type == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, gemmParams.getA(), gemmParams.getB(), gemmParams.getC());
            dgemm(c, gemmParams.getTransA(), gemmParams.getTransB(), gemmParams.getM(), gemmParams.getN(), gemmParams.getK(), d, gemmParams.getA(), gemmParams.getLda(), gemmParams.getB(), gemmParams.getLdb(), d2, iNDArray3, gemmParams.getLdc());
        } else if (type == DataBuffer.Type.FLOAT) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, gemmParams.getA(), gemmParams.getB(), gemmParams.getC());
            sgemm(c, gemmParams.getTransA(), gemmParams.getTransB(), gemmParams.getM(), gemmParams.getN(), gemmParams.getK(), (float) d, gemmParams.getA(), gemmParams.getLda(), gemmParams.getB(), gemmParams.getLdb(), (float) d2, iNDArray3, gemmParams.getLdc());
        } else {
            // Anything that is neither DOUBLE nor FLOAT is validated as, and dispatched to, HALF.
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, gemmParams.getA(), gemmParams.getB(), gemmParams.getC());
            hgemm(c, gemmParams.getTransA(), gemmParams.getTransB(), gemmParams.getM(), gemmParams.getN(), gemmParams.getK(), (float) d, gemmParams.getA(), gemmParams.getLda(), gemmParams.getB(), gemmParams.getLdb(), (float) d2, iNDArray3, gemmParams.getLdc());
        }
        OpExecutionerUtil.checkForAny(iNDArray3);
    }

    /**
     * General matrix-matrix multiply with explicit transpose requests.
     *
     * <p>The {@link GemmParams} is built from the arrays and the two booleans; the ordering
     * argument passed to the kernel is {@code iNDArray.ordering()}.
     *
     * @param iNDArray  first operand (A); its data type selects the kernel
     * @param iNDArray2 second operand (B)
     * @param iNDArray3 result array (C)
     * @param z         forwarded to {@link GemmParams} as its fourth argument (transpose A)
     * @param z2        forwarded to {@link GemmParams} as its fifth argument (transpose B)
     * @param d         scalar applied to the matrix product (alpha)
     * @param d2        scalar applied to the existing contents of the result (beta)
     */
    @Override
    public void gemm(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, boolean z, boolean z2, double d, double d2) {
        if (profilingAllOps()) {
            OpProfiler.getInstance().processBlasCall(true, iNDArray, iNDArray2, iNDArray3);
        }
        GemmParams gemmParams = new GemmParams(iNDArray, iNDArray2, iNDArray3, z, z2);
        DataBuffer.Type type = iNDArray.data().dataType();
        if (type == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, gemmParams.getA(), gemmParams.getB(), iNDArray3);
            dgemm(iNDArray.ordering(), gemmParams.getTransA(), gemmParams.getTransB(), gemmParams.getM(), gemmParams.getN(), gemmParams.getK(), d, gemmParams.getA(), gemmParams.getLda(), gemmParams.getB(), gemmParams.getLdb(), d2, iNDArray3, gemmParams.getLdc());
        } else if (type == DataBuffer.Type.FLOAT) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, gemmParams.getA(), gemmParams.getB(), iNDArray3);
            sgemm(iNDArray.ordering(), gemmParams.getTransA(), gemmParams.getTransB(), gemmParams.getM(), gemmParams.getN(), gemmParams.getK(), (float) d, gemmParams.getA(), gemmParams.getLda(), gemmParams.getB(), gemmParams.getLdb(), (float) d2, iNDArray3, gemmParams.getLdc());
        } else {
            // Anything that is neither DOUBLE nor FLOAT is validated as, and dispatched to, HALF.
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, gemmParams.getA(), gemmParams.getB(), iNDArray3);
            hgemm(iNDArray.ordering(), gemmParams.getTransA(), gemmParams.getTransB(), gemmParams.getM(), gemmParams.getN(), gemmParams.getK(), (float) d, gemmParams.getA(), gemmParams.getLda(), gemmParams.getB(), gemmParams.getLdb(), (float) d2, iNDArray3, gemmParams.getLdc());
        }
        OpExecutionerUtil.checkForAny(iNDArray3);
    }

    /**
     * Symmetric matrix-matrix multiply (BLAS {@code symm}).
     *
     * <p>The two dimension arguments given to the kernel are the row and column counts of
     * {@code iNDArray3}; each leading dimension is {@code size(0)} of the corresponding array.
     * Any data type other than DOUBLE is validated as, and dispatched to, the FLOAT kernel.
     *
     * @param c         first flag forwarded to the kernel
     * @param c2        second flag forwarded to the kernel
     * @param c3        third flag forwarded to the kernel
     * @param d         scalar forwarded ahead of the first operand (alpha)
     * @param iNDArray  first operand (A); its data type selects the kernel
     * @param iNDArray2 second operand (B)
     * @param d2        scalar forwarded ahead of the result array (beta)
     * @param iNDArray3 result array (C)
     */
    @Override
    public void symm(char c, char c2, char c3, double d, INDArray iNDArray, INDArray iNDArray2, double d2, INDArray iNDArray3) {
        if (profilingAllOps()) {
            OpProfiler.getInstance().processBlasCall(false, iNDArray, iNDArray2, iNDArray3);
        }
        int ldA = (int) iNDArray.size(0);
        int ldB = (int) iNDArray2.size(0);
        int ldC = (int) iNDArray3.size(0);
        if (iNDArray.data().dataType() == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, iNDArray, iNDArray2, iNDArray3);
            dsymm(c, c2, c3, iNDArray3.rows(), iNDArray3.columns(), d, iNDArray, ldA, iNDArray2, ldB, d2, iNDArray3, ldC);
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, iNDArray, iNDArray2, iNDArray3);
            ssymm(c, c2, c3, iNDArray3.rows(), iNDArray3.columns(), (float) d, iNDArray, ldA, iNDArray2, ldB, (float) d2, iNDArray3, ldC);
        }
        OpExecutionerUtil.checkForAny(iNDArray3);
    }

    /**
     * Symmetric rank-k update (BLAS {@code syrk}).
     *
     * <p>The first dimension argument is the row count of {@code iNDArray2}; the second is the
     * constant {@code 1}. Any data type other than DOUBLE is validated as, and dispatched to,
     * the FLOAT kernel.
     *
     * @param c         first flag forwarded to the kernel
     * @param c2        second flag forwarded to the kernel
     * @param c3        third flag forwarded to the kernel
     * @param d         scalar forwarded ahead of the input operand (alpha)
     * @param iNDArray  input operand (A); its data type selects the kernel
     * @param d2        scalar forwarded ahead of the result array (beta)
     * @param iNDArray2 result array (C)
     */
    @Override
    public void syrk(char c, char c2, char c3, double d, INDArray iNDArray, double d2, INDArray iNDArray2) {
        if (profilingAllOps()) {
            OpProfiler.getInstance().processBlasCall(false, iNDArray, iNDArray2);
        }
        int ldA = (int) iNDArray.size(0);
        int ldC = (int) iNDArray2.size(0);
        // NOTE(review): the second dimension is hard-coded to 1 rather than derived from
        // iNDArray's shape - confirm this is intended before relying on it for wider inputs.
        if (iNDArray.data().dataType() == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, iNDArray, iNDArray2);
            dsyrk(c, c2, c3, iNDArray2.rows(), 1, d, iNDArray, ldA, d2, iNDArray2, ldC);
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, iNDArray, iNDArray2);
            ssyrk(c, c2, c3, iNDArray2.rows(), 1, (float) d, iNDArray, ldA, (float) d2, iNDArray2, ldC);
        }
        OpExecutionerUtil.checkForAny(iNDArray2);
    }

    /**
     * Symmetric rank-2k update (BLAS {@code syr2k}).
     *
     * <p>The two dimension arguments given to the kernel are the row and column counts of
     * {@code iNDArray}; each leading dimension is {@code size(0)} of the corresponding array.
     * Any data type other than DOUBLE is validated as, and dispatched to, the FLOAT kernel.
     *
     * @param c         first flag forwarded to the kernel
     * @param c2        second flag forwarded to the kernel
     * @param c3        third flag forwarded to the kernel
     * @param d         scalar forwarded ahead of the first operand (alpha)
     * @param iNDArray  first operand (A); its data type selects the kernel
     * @param iNDArray2 second operand (B)
     * @param d2        scalar forwarded ahead of the result array (beta)
     * @param iNDArray3 result array (C)
     */
    @Override
    public void syr2k(char c, char c2, char c3, double d, INDArray iNDArray, INDArray iNDArray2, double d2, INDArray iNDArray3) {
        if (profilingAllOps()) {
            OpProfiler.getInstance().processBlasCall(false, iNDArray, iNDArray2, iNDArray3);
        }
        int ldA = (int) iNDArray.size(0);
        int ldB = (int) iNDArray2.size(0);
        int ldC = (int) iNDArray3.size(0);
        if (iNDArray.data().dataType() == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, iNDArray, iNDArray2, iNDArray3);
            dsyr2k(c, c2, c3, iNDArray.rows(), iNDArray.columns(), d, iNDArray, ldA, iNDArray2, ldB, d2, iNDArray3, ldC);
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, iNDArray, iNDArray2, iNDArray3);
            ssyr2k(c, c2, c3, iNDArray.rows(), iNDArray.columns(), (float) d, iNDArray, ldA, iNDArray2, ldB, (float) d2, iNDArray3, ldC);
        }
        OpExecutionerUtil.checkForAny(iNDArray3);
    }

    /**
     * Triangular matrix-matrix multiply (BLAS {@code trmm}).
     *
     * <p>Only {@code iNDArray} and {@code iNDArray2} are handed to the kernel. {@code iNDArray3}
     * takes part in profiling, data-type validation and the final
     * {@link OpExecutionerUtil#checkForAny(INDArray)} call, but the kernel never receives it.
     * Any data type other than DOUBLE is validated as, and dispatched to, the FLOAT kernel.
     *
     * @param c         first flag forwarded to the kernel
     * @param c2        second flag forwarded to the kernel
     * @param c3        third flag forwarded to the kernel
     * @param c4        fourth flag forwarded to the kernel
     * @param c5        fifth flag forwarded to the kernel
     * @param d         scalar forwarded ahead of the first operand (alpha)
     * @param iNDArray  first operand (A); its data type and shape drive the call
     * @param iNDArray2 second operand (B)
     * @param iNDArray3 array that is validated and checked but not passed to the kernel
     */
    @Override
    public void trmm(char c, char c2, char c3, char c4, char c5, double d, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        if (profilingAllOps()) {
            OpProfiler.getInstance().processBlasCall(false, iNDArray, iNDArray2, iNDArray3);
        }
        int ldA = (int) iNDArray.size(0);
        int ldB = (int) iNDArray2.size(0);
        if (iNDArray.data().dataType() == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, iNDArray, iNDArray2, iNDArray3);
            dtrmm(c, c2, c3, c4, c5, iNDArray.rows(), iNDArray.columns(), d, iNDArray, ldA, iNDArray2, ldB);
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, iNDArray, iNDArray2, iNDArray3);
            strmm(c, c2, c3, c4, c5, iNDArray.rows(), iNDArray.columns(), (float) d, iNDArray, ldA, iNDArray2, ldB);
        }
        // NOTE(review): the array checked here is not the one the kernel was given - confirm
        // whether iNDArray2 was the intended target.
        OpExecutionerUtil.checkForAny(iNDArray3);
    }

    /**
     * Triangular solve with multiple right-hand sides (BLAS {@code trsm}).
     *
     * <p>The two dimension arguments given to the kernel are the row and column counts of
     * {@code iNDArray}; each leading dimension is {@code size(0)} of the corresponding array.
     * Any data type other than DOUBLE is validated as, and dispatched to, the FLOAT kernel.
     *
     * @param c         first flag forwarded to the kernel
     * @param c2        second flag forwarded to the kernel
     * @param c3        third flag forwarded to the kernel
     * @param c4        fourth flag forwarded to the kernel
     * @param c5        fifth flag forwarded to the kernel
     * @param d         scalar forwarded ahead of the first operand (alpha)
     * @param iNDArray  first operand (A); its data type and shape drive the call
     * @param iNDArray2 second operand (B); also the array checked after the kernel returns
     */
    @Override
    public void trsm(char c, char c2, char c3, char c4, char c5, double d, INDArray iNDArray, INDArray iNDArray2) {
        if (profilingAllOps()) {
            OpProfiler.getInstance().processBlasCall(false, iNDArray, iNDArray2);
        }
        int ldA = (int) iNDArray.size(0);
        int ldB = (int) iNDArray2.size(0);
        if (iNDArray.data().dataType() == DataBuffer.Type.DOUBLE) {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, iNDArray, iNDArray2);
            dtrsm(c, c2, c3, c4, c5, iNDArray.rows(), iNDArray.columns(), d, iNDArray, ldA, iNDArray2, ldB);
        } else {
            DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, iNDArray, iNDArray2);
            strsm(c, c2, c3, c4, c5, iNDArray.rows(), iNDArray.columns(), (float) d, iNDArray, ldA, iNDArray2, ldB);
        }
        OpExecutionerUtil.checkForAny(iNDArray2);
    }

    /** Half-precision gemm kernel; invoked by both {@code gemm} overloads for non-DOUBLE, non-FLOAT data. */
    protected abstract void hgemm(char c, char c2, char c3, int i, int i2, int i3, float f, INDArray iNDArray, int i4, INDArray iNDArray2, int i5, float f2, INDArray iNDArray3, int i6);

    /** Single-precision gemm kernel; invoked by both {@code gemm} overloads for FLOAT data. */
    protected abstract void sgemm(char c, char c2, char c3, int i, int i2, int i3, float f, INDArray iNDArray, int i4, INDArray iNDArray2, int i5, float f2, INDArray iNDArray3, int i6);

    /** Single-precision symm kernel; invoked by {@link #symm} for non-DOUBLE data. */
    protected abstract void ssymm(char c, char c2, char c3, int i, int i2, float f, INDArray iNDArray, int i3, INDArray iNDArray2, int i4, float f2, INDArray iNDArray3, int i5);

    /** Single-precision syrk kernel; invoked by {@link #syrk} for non-DOUBLE data. */
    protected abstract void ssyrk(char c, char c2, char c3, int i, int i2, float f, INDArray iNDArray, int i3, float f2, INDArray iNDArray2, int i4);

    /** Single-precision syr2k kernel; invoked by {@link #syr2k} for non-DOUBLE data. */
    protected abstract void ssyr2k(char c, char c2, char c3, int i, int i2, float f, INDArray iNDArray, int i3, INDArray iNDArray2, int i4, float f2, INDArray iNDArray3, int i5);

    /** Single-precision trmm kernel; invoked by {@link #trmm} for non-DOUBLE data. */
    protected abstract void strmm(char c, char c2, char c3, char c4, char c5, int i, int i2, float f, INDArray iNDArray, int i3, INDArray iNDArray2, int i4);

    /** Single-precision trsm kernel; invoked by {@link #trsm} for non-DOUBLE data. */
    protected abstract void strsm(char c, char c2, char c3, char c4, char c5, int i, int i2, float f, INDArray iNDArray, int i3, INDArray iNDArray2, int i4);

    /** Double-precision gemm kernel; invoked by both {@code gemm} overloads for DOUBLE data. */
    protected abstract void dgemm(char c, char c2, char c3, int i, int i2, int i3, double d, INDArray iNDArray, int i4, INDArray iNDArray2, int i5, double d2, INDArray iNDArray3, int i6);

    /** Double-precision symm kernel; invoked by {@link #symm} for DOUBLE data. */
    protected abstract void dsymm(char c, char c2, char c3, int i, int i2, double d, INDArray iNDArray, int i3, INDArray iNDArray2, int i4, double d2, INDArray iNDArray3, int i5);

    /** Double-precision syrk kernel; invoked by {@link #syrk} for DOUBLE data. */
    protected abstract void dsyrk(char c, char c2, char c3, int i, int i2, double d, INDArray iNDArray, int i3, double d2, INDArray iNDArray2, int i4);

    /** Double-precision syr2k kernel; invoked by {@link #syr2k} for DOUBLE data. */
    protected abstract void dsyr2k(char c, char c2, char c3, int i, int i2, double d, INDArray iNDArray, int i3, INDArray iNDArray2, int i4, double d2, INDArray iNDArray3, int i5);

    /** Double-precision trmm kernel; invoked by {@link #trmm} for DOUBLE data. */
    protected abstract void dtrmm(char c, char c2, char c3, char c4, char c5, int i, int i2, double d, INDArray iNDArray, int i3, INDArray iNDArray2, int i4);

    /** Double-precision trsm kernel; invoked by {@link #trsm} for DOUBLE data. */
    protected abstract void dtrsm(char c, char c2, char c3, char c4, char c5, int i, int i2, double d, INDArray iNDArray, int i3, INDArray iNDArray2, int i4);
}
