diff --git a/src/main/java/org/dromara/easyai/matrixTools/CudaMatrix.java b/src/main/java/org/dromara/easyai/matrixTools/CudaMatrix.java new file mode 100644 index 0000000..6ede76b --- /dev/null +++ b/src/main/java/org/dromara/easyai/matrixTools/CudaMatrix.java @@ -0,0 +1,22 @@ +package org.dromara.easyai.matrixTools; + +public interface CudaMatrix { + void init() throws Exception; + void softMax(Matrix matrix) throws Exception; + + Matrix matrixSoftMaxPd(Matrix qkt, Matrix errorMatrix, float wordVectorDimension) throws Exception; + + Matrix mulMatrix(Matrix matrix1, Matrix matrix2) throws Exception; + + // 矩阵数加 + void mathAdd(Matrix matrix, float nub) throws Exception; + + // 矩阵数减 + void mathSub(Matrix matrix, float nub) throws Exception; + + // 矩阵数乘 + void mathMul(Matrix matrix, float nub) throws Exception; + + // 矩阵数除 + void mathDiv(Matrix matrix, float nub) throws Exception; +}