From aa0086fbb6737e520017f550d1d62528f0cfe8f8 Mon Sep 17 00:00:00 2001 From: inter Date: Thu, 4 Sep 2025 14:09:22 +0800 Subject: [PATCH] Add File --- .../easyai/transFormer/TransWordVector.java | 302 ++++++++++++++++++ 1 file changed, 302 insertions(+) create mode 100644 src/main/java/org/dromara/easyai/transFormer/TransWordVector.java diff --git a/src/main/java/org/dromara/easyai/transFormer/TransWordVector.java b/src/main/java/org/dromara/easyai/transFormer/TransWordVector.java new file mode 100644 index 0000000..b9e3838 --- /dev/null +++ b/src/main/java/org/dromara/easyai/transFormer/TransWordVector.java @@ -0,0 +1,302 @@ +package org.dromara.easyai.transFormer; + +import org.dromara.easyai.config.TfConfig; +import org.dromara.easyai.matrixTools.Matrix; +import org.dromara.easyai.matrixTools.MatrixList; +import org.dromara.easyai.matrixTools.MatrixOperation; +import org.dromara.easyai.transFormer.model.TransWordVectorModel; + +import java.util.*; + +/** + * @author lidapeng + * @time 2025/3/29 09:11 + * @des transFormer 词向量 + */ +public class TransWordVector { + private final List wordList = new ArrayList<>();//词离散id + private final List wordVectorList = new ArrayList<>();//词向量 + private final Matrix positionCodeMatrix;//位置编码矩阵 + private final WordIds wordIds = new WordIds(); + private final String splitWord; + private final int featureDimension;//词向量维度 + private final Random random = new Random(); + private final String startWord; + private final String endWord; + private final MatrixOperation matrixOperation = new MatrixOperation(); + private final float studyRate; + private final int maxLength;//最大长度 + + public int getEndID() {//返回结束离散id,约定为2 + return 2; + } + + public int getStartID() {//返回开始离散id,约定为1 + return 1; + } + + public TransWordVector(TfConfig tfConfig) throws Exception { + this.splitWord = tfConfig.getSplitWord(); + this.studyRate = tfConfig.getStudyRate(); + this.featureDimension = tfConfig.getFeatureDimension(); + startWord = tfConfig.getStartWord(); + endWord = tfConfig.getEndWord(); + maxLength = tfConfig.getMaxLength() + 2; + positionCodeMatrix = new Matrix(maxLength, featureDimension); + wordList.add(startWord); + wordList.add(endWord); + initWordVector(); + initWordVector(); + initPositionMatrix(); + } + + private void initPositionMatrix() throws Exception {//初始化位置嵌入编码 + int x = positionCodeMatrix.getX(); + int y = positionCodeMatrix.getY(); + Random random = new Random(); + for (int i = 0; i < x; i++) { + for (int j = 0; j < y; j++) { + float value = random.nextFloat(); + if (i == 0) { + value = value + 1f; + } + positionCodeMatrix.setNub(i, j, value); + } + } + } + + public TransWordVectorModel getModel() { + TransWordVectorModel transWordVectorModel = new TransWordVectorModel(); + transWordVectorModel.setWordList(wordList); + transWordVectorModel.setPositionMatrix(positionCodeMatrix.getMatrixModel()); + transWordVectorModel.setX(wordVectorList.get(0).getX()); + transWordVectorModel.setY(wordVectorList.get(0).getY()); + List wordVectorModel = new ArrayList<>(); + transWordVectorModel.setWordVectorModel(wordVectorModel); + for (Matrix matrix : wordVectorList) { + wordVectorModel.add(matrix.getMatrixModel()); + } + return transWordVectorModel; + } + + public void insertModel(TransWordVectorModel transWordVectorModel) { + int x = transWordVectorModel.getX(); + int y = transWordVectorModel.getY(); + wordList.clear(); + wordVectorList.clear(); + wordList.addAll(transWordVectorModel.getWordList()); + positionCodeMatrix.insertMatrixModel(transWordVectorModel.getPositionMatrix()); + List wordVectorModel = transWordVectorModel.getWordVectorModel(); + for (Float[] floats : wordVectorModel) { + Matrix matrix = new Matrix(x, y); + matrix.insertMatrixModel(floats); + wordVectorList.add(matrix); + } + } + + public void backEncoderError(Matrix error) throws Exception { + List ids = wordIds.getEncoder(); + int size = ids.size(); + if (size == error.getX()) { + updateWordVector(ids, error); + wordIds.getEncoder().clear(); + } else { + throw new Exception("编码器误差返回长度不一致,size:" + size + ",errorSize:" + error.getX()); + } + } + + private void updatePositionCode(Matrix error) throws Exception { + int x = error.getX(); + int y = error.getY(); + for (int i = 0; i < x; i++) { + for (int j = 0; j < y; j++) { + float value = positionCodeMatrix.getNumber(i, j) + error.getNumber(i, j); + positionCodeMatrix.setNub(i, j, value); + } + } + } + + private void updateWordVector(List ids, Matrix error) throws Exception { + int size = ids.size(); + matrixOperation.mathMul(error, studyRate); + updatePositionCode(error); + for (int i = 0; i < size; i++) { + int index = ids.get(i); + Matrix wordError = error.getRow(i); + Matrix wordVector = wordVectorList.get(index); + wordVector = matrixOperation.add(wordVector, wordError); + wordVectorList.set(index, wordVector); + } + } + + public void backDecoderError(Matrix errorMatrix, Matrix allFeature) throws Exception { + Matrix error = matrixOperation.add(errorMatrix, allFeature); + List ids = wordIds.getDecoder(); + int size = ids.size(); + if (size == error.getX()) { + updateWordVector(ids, error); + wordIds.getDecoder().clear(); + } else { + throw new Exception("解码器误差返回长度不一致"); + } + } + + public String getWordByID(int id) {//通过离散id获取字符 + return wordList.get(id - 1); + } + + public int getWordID(String word) {//获取离散id + int id = -1; + int size = wordList.size(); + for (int i = 0; i < size; i++) { + if (wordList.get(i).equals(word)) { + id = i + 1; + break; + } + } + return id; + } + + public List getE(String word) {//获取期望 + List result = new ArrayList<>(); + if (splitWord == null) { + for (int i = 0; i < word.length(); i++) { + result.add(getWordID(word.substring(i, i + 1))); + } + } else { + String[] words = word.split(splitWord); + for (String s : words) { + result.add(getWordID(s)); + } + } + result.add(2); + return result; + } + + public Matrix getVector(String word) {//通过word获取对应的词向量 + int size = wordList.size(); + Matrix feature = null; + for (int i = 0; i < size; i++) { + if (wordList.get(i).equals(word)) { + feature = wordVectorList.get(i); + break; + } + } + return feature; + } + + private Matrix getVectorByStudy(String word, boolean decoder, boolean study) { + int size = wordList.size(); + Matrix feature = null; + List ids = null; + if (decoder && study) { + ids = wordIds.getDecoder(); + } else if (!decoder && study) { + ids = wordIds.getEncoder(); + } + for (int i = 0; i < size; i++) { + if (wordList.get(i).equals(word)) { + if (ids != null) { + ids.add(i); + } + feature = wordVectorList.get(i); + break; + } + } + if (feature == null) { + feature = new Matrix(1, featureDimension); + } + return feature; + } + + public Matrix getWordVector(String word, boolean decoder, boolean study) throws Exception {//获取词向量 + MatrixList matrixList = null; + if (decoder) { + if (study) { + wordIds.getDecoder().add(0); + } + matrixList = new MatrixList(wordVectorList.get(0), true, maxLength + 10); + } + if (word != null && !word.isEmpty()) { + if (word.length() > maxLength - 2) { + throw new Exception("语句长度超过设定的最大值"); + } + if (splitWord == null) { + int size = word.length(); + for (int i = 0; i < size; i++) { + Matrix feature = getVectorByStudy(word.substring(i, i + 1), decoder, study); + if (matrixList == null) { + matrixList = new MatrixList(feature, true, maxLength + 10); + } else { + matrixList.add(feature); + } + } + } else { + String[] myWord = word.split(splitWord); + for (String s : myWord) { + Matrix feature = getVectorByStudy(s, decoder, study); + if (matrixList == null) { + matrixList = new MatrixList(feature, true, maxLength + 10); + } else { + matrixList.add(feature); + } + } + } + } + return addPositionMatrix(matrixList.getMatrix()); + } + + private Matrix addPositionMatrix(Matrix matrix) throws Exception {//与位置编码相加 + int x = matrix.getX(); + int y = matrix.getY(); + Matrix positionCode = positionCodeMatrix.getSonOfMatrix(0, 0, x, y); + return matrixOperation.add(matrix, positionCode); + } + + private void initWordVector() throws Exception { + Matrix matrix = new Matrix(1, featureDimension); + for (int j = 0; j < featureDimension; j++) { + matrix.setNub(0, j, random.nextFloat()); + } + wordVectorList.add(matrix); + } + + private void insertWord(String word) throws Exception { + if (!word.equals(startWord) && !word.equals(endWord)) { + boolean here = false; + for (String myWord : wordList) { + if (myWord.equals(word)) { + here = true; + break; + } + } + if (!here) { + wordList.add(word); + initWordVector(); + } + } else { + throw new Exception("任何字词不可以与结束符或开始符重叠"); + } + } + + public void init(List sentenceList) throws Exception { + for (String sentence : sentenceList) { + if (sentence != null && !sentence.isEmpty()) { + if (splitWord == null) { + for (int i = 0; i < sentence.length(); i++) { + insertWord(sentence.substring(i, i + 1)); + } + } else { + String[] myWord = sentence.split(splitWord); + for (String s : myWord) { + insertWord(s); + } + } + } + } + } + + public int getWordSize() { + return wordList.size(); + } +}