This commit is contained in:
2025-09-04 14:09:22 +08:00
parent 3e9fb19d66
commit aa0086fbb6

View File

@@ -0,0 +1,302 @@
package org.dromara.easyai.transFormer;
import org.dromara.easyai.config.TfConfig;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixList;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.transFormer.model.TransWordVectorModel;
import java.util.*;
/**
* @author lidapeng
* @time 2025/3/29 09:11
* @des transFormer 词向量
*/
public class TransWordVector {
private final List<String> wordList = new ArrayList<>();//词离散id
private final List<Matrix> wordVectorList = new ArrayList<>();//词向量
private final Matrix positionCodeMatrix;//位置编码矩阵
private final WordIds wordIds = new WordIds();
private final String splitWord;
private final int featureDimension;//词向量维度
private final Random random = new Random();
private final String startWord;
private final String endWord;
private final MatrixOperation matrixOperation = new MatrixOperation();
private final float studyRate;
private final int maxLength;//最大长度
public int getEndID() {//返回结束离散id约定为2
return 2;
}
public int getStartID() {//返回开始离散id约定为1
return 1;
}
public TransWordVector(TfConfig tfConfig) throws Exception {
this.splitWord = tfConfig.getSplitWord();
this.studyRate = tfConfig.getStudyRate();
this.featureDimension = tfConfig.getFeatureDimension();
startWord = tfConfig.getStartWord();
endWord = tfConfig.getEndWord();
maxLength = tfConfig.getMaxLength() + 2;
positionCodeMatrix = new Matrix(maxLength, featureDimension);
wordList.add(startWord);
wordList.add(endWord);
initWordVector();
initWordVector();
initPositionMatrix();
}
private void initPositionMatrix() throws Exception {//初始化位置嵌入编码
int x = positionCodeMatrix.getX();
int y = positionCodeMatrix.getY();
Random random = new Random();
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
float value = random.nextFloat();
if (i == 0) {
value = value + 1f;
}
positionCodeMatrix.setNub(i, j, value);
}
}
}
public TransWordVectorModel getModel() {
TransWordVectorModel transWordVectorModel = new TransWordVectorModel();
transWordVectorModel.setWordList(wordList);
transWordVectorModel.setPositionMatrix(positionCodeMatrix.getMatrixModel());
transWordVectorModel.setX(wordVectorList.get(0).getX());
transWordVectorModel.setY(wordVectorList.get(0).getY());
List<Float[]> wordVectorModel = new ArrayList<>();
transWordVectorModel.setWordVectorModel(wordVectorModel);
for (Matrix matrix : wordVectorList) {
wordVectorModel.add(matrix.getMatrixModel());
}
return transWordVectorModel;
}
public void insertModel(TransWordVectorModel transWordVectorModel) {
int x = transWordVectorModel.getX();
int y = transWordVectorModel.getY();
wordList.clear();
wordVectorList.clear();
wordList.addAll(transWordVectorModel.getWordList());
positionCodeMatrix.insertMatrixModel(transWordVectorModel.getPositionMatrix());
List<Float[]> wordVectorModel = transWordVectorModel.getWordVectorModel();
for (Float[] floats : wordVectorModel) {
Matrix matrix = new Matrix(x, y);
matrix.insertMatrixModel(floats);
wordVectorList.add(matrix);
}
}
public void backEncoderError(Matrix error) throws Exception {
List<Integer> ids = wordIds.getEncoder();
int size = ids.size();
if (size == error.getX()) {
updateWordVector(ids, error);
wordIds.getEncoder().clear();
} else {
throw new Exception("编码器误差返回长度不一致,size:" + size + ",errorSize:" + error.getX());
}
}
private void updatePositionCode(Matrix error) throws Exception {
int x = error.getX();
int y = error.getY();
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
float value = positionCodeMatrix.getNumber(i, j) + error.getNumber(i, j);
positionCodeMatrix.setNub(i, j, value);
}
}
}
private void updateWordVector(List<Integer> ids, Matrix error) throws Exception {
int size = ids.size();
matrixOperation.mathMul(error, studyRate);
updatePositionCode(error);
for (int i = 0; i < size; i++) {
int index = ids.get(i);
Matrix wordError = error.getRow(i);
Matrix wordVector = wordVectorList.get(index);
wordVector = matrixOperation.add(wordVector, wordError);
wordVectorList.set(index, wordVector);
}
}
public void backDecoderError(Matrix errorMatrix, Matrix allFeature) throws Exception {
Matrix error = matrixOperation.add(errorMatrix, allFeature);
List<Integer> ids = wordIds.getDecoder();
int size = ids.size();
if (size == error.getX()) {
updateWordVector(ids, error);
wordIds.getDecoder().clear();
} else {
throw new Exception("解码器误差返回长度不一致");
}
}
public String getWordByID(int id) {//通过离散id获取字符
return wordList.get(id - 1);
}
public int getWordID(String word) {//获取离散id
int id = -1;
int size = wordList.size();
for (int i = 0; i < size; i++) {
if (wordList.get(i).equals(word)) {
id = i + 1;
break;
}
}
return id;
}
public List<Integer> getE(String word) {//获取期望
List<Integer> result = new ArrayList<>();
if (splitWord == null) {
for (int i = 0; i < word.length(); i++) {
result.add(getWordID(word.substring(i, i + 1)));
}
} else {
String[] words = word.split(splitWord);
for (String s : words) {
result.add(getWordID(s));
}
}
result.add(2);
return result;
}
public Matrix getVector(String word) {//通过word获取对应的词向量
int size = wordList.size();
Matrix feature = null;
for (int i = 0; i < size; i++) {
if (wordList.get(i).equals(word)) {
feature = wordVectorList.get(i);
break;
}
}
return feature;
}
private Matrix getVectorByStudy(String word, boolean decoder, boolean study) {
int size = wordList.size();
Matrix feature = null;
List<Integer> ids = null;
if (decoder && study) {
ids = wordIds.getDecoder();
} else if (!decoder && study) {
ids = wordIds.getEncoder();
}
for (int i = 0; i < size; i++) {
if (wordList.get(i).equals(word)) {
if (ids != null) {
ids.add(i);
}
feature = wordVectorList.get(i);
break;
}
}
if (feature == null) {
feature = new Matrix(1, featureDimension);
}
return feature;
}
public Matrix getWordVector(String word, boolean decoder, boolean study) throws Exception {//获取词向量
MatrixList matrixList = null;
if (decoder) {
if (study) {
wordIds.getDecoder().add(0);
}
matrixList = new MatrixList(wordVectorList.get(0), true, maxLength + 10);
}
if (word != null && !word.isEmpty()) {
if (word.length() > maxLength - 2) {
throw new Exception("语句长度超过设定的最大值");
}
if (splitWord == null) {
int size = word.length();
for (int i = 0; i < size; i++) {
Matrix feature = getVectorByStudy(word.substring(i, i + 1), decoder, study);
if (matrixList == null) {
matrixList = new MatrixList(feature, true, maxLength + 10);
} else {
matrixList.add(feature);
}
}
} else {
String[] myWord = word.split(splitWord);
for (String s : myWord) {
Matrix feature = getVectorByStudy(s, decoder, study);
if (matrixList == null) {
matrixList = new MatrixList(feature, true, maxLength + 10);
} else {
matrixList.add(feature);
}
}
}
}
return addPositionMatrix(matrixList.getMatrix());
}
private Matrix addPositionMatrix(Matrix matrix) throws Exception {//与位置编码相加
int x = matrix.getX();
int y = matrix.getY();
Matrix positionCode = positionCodeMatrix.getSonOfMatrix(0, 0, x, y);
return matrixOperation.add(matrix, positionCode);
}
private void initWordVector() throws Exception {
Matrix matrix = new Matrix(1, featureDimension);
for (int j = 0; j < featureDimension; j++) {
matrix.setNub(0, j, random.nextFloat());
}
wordVectorList.add(matrix);
}
private void insertWord(String word) throws Exception {
if (!word.equals(startWord) && !word.equals(endWord)) {
boolean here = false;
for (String myWord : wordList) {
if (myWord.equals(word)) {
here = true;
break;
}
}
if (!here) {
wordList.add(word);
initWordVector();
}
} else {
throw new Exception("任何字词不可以与结束符或开始符重叠");
}
}
public void init(List<String> sentenceList) throws Exception {
for (String sentence : sentenceList) {
if (sentence != null && !sentence.isEmpty()) {
if (splitWord == null) {
for (int i = 0; i < sentence.length(); i++) {
insertWord(sentence.substring(i, i + 1));
}
} else {
String[] myWord = sentence.split(splitWord);
for (String s : myWord) {
insertWord(s);
}
}
}
}
}
public int getWordSize() {
return wordList.size();
}
}