diff --git a/src/main/java/org/dromara/easyai/naturalLanguage/word/WordEmbedding.java b/src/main/java/org/dromara/easyai/naturalLanguage/word/WordEmbedding.java new file mode 100644 index 0000000..e7be89a --- /dev/null +++ b/src/main/java/org/dromara/easyai/naturalLanguage/word/WordEmbedding.java @@ -0,0 +1,175 @@ +package org.dromara.easyai.naturalLanguage.word; + + +import org.dromara.easyai.matrixTools.Matrix; +import org.dromara.easyai.matrixTools.MatrixList; +import org.dromara.easyai.matrixTools.MatrixOperation; +import org.dromara.easyai.config.RZ; +import org.dromara.easyai.config.SentenceConfig; +import org.dromara.easyai.entity.SentenceModel; +import org.dromara.easyai.entity.WordMatrix; +import org.dromara.easyai.entity.WordTwoVectorModel; +import org.dromara.easyai.function.Tanh; +import org.dromara.easyai.i.OutBack; +import org.dromara.easyai.rnnJumpNerveEntity.MyWordFeature; +import org.dromara.easyai.rnnNerveCenter.NerveManager; +import org.dromara.easyai.rnnNerveEntity.SensoryNerve; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description 词嵌入向量训练 + */ +public class WordEmbedding extends MatrixOperation { + private NerveManager nerveManager; + private SentenceModel sentenceModel; + private final List wordList = new ArrayList<>();//单字集合 + private SentenceConfig config; + private int wordVectorDimension; + private int studyTimes = 1; + + public void setStudyTimes(int studyTimes) { + this.studyTimes = studyTimes; + } + + public void setConfig(SentenceConfig config) { + this.config = config; + } + + public int getWordVectorDimension() { + return wordVectorDimension; + } + + public void init(SentenceModel sentenceModel, int wordVectorDimension) throws Exception { + this.wordVectorDimension = wordVectorDimension; + this.sentenceModel = sentenceModel; + wordList.addAll(sentenceModel.getWordSet()); + nerveManager = new NerveManager(wordList.size(), wordVectorDimension, wordList.size() + , 1, new Tanh(), config.getWeStudyPoint(), config.getRzModel(), + config.getWeLParam()); + nerveManager.init(true, false, true); + } + + public List getWordList() { + return wordList; + } + + public String getWord(int id) { + return wordList.get(id); + } + + public void insertModel(WordTwoVectorModel wordTwoVectorModel, int wordVectorDimension) throws Exception { + wordList.clear(); + this.wordVectorDimension = wordVectorDimension; + List myWordList = wordTwoVectorModel.getWordList(); + wordList.addAll(myWordList); + nerveManager = new NerveManager(wordList.size(), wordVectorDimension, wordList.size() + , 1, new Tanh(), config.getWeStudyPoint(), RZ.NOT_RZ, 0); + nerveManager.init(true, false, true); + nerveManager.insertModelParameter(wordTwoVectorModel.getModelParameter()); + } + + public MyWordFeature getEmbedding(String word, long eventId, boolean once) throws Exception {//做截断 + MyWordFeature myWordFeature = new MyWordFeature(); + int wordDim = wordVectorDimension;// + MatrixList matrixList = null; + for (int i = 0; i < word.length(); i++) { + WordMatrix wordMatrix = new WordMatrix(wordDim); + String myWord; + if (!once) { + myWord = word.substring(i, i + 1); + } else { + myWord = word; + } + int index = getID(myWord); + studyDNN(eventId, index, 0, wordMatrix, false); + if (matrixList == null) { + myWordFeature.setFirstFeatureList(wordMatrix.getList()); + matrixList = new MatrixList(wordMatrix.getVector(), true); + } else { + matrixList.add(wordMatrix.getVector()); + } + if (once) { + break; + } + } + myWordFeature.setFeatureMatrix(matrixList.getMatrix()); + return myWordFeature; + } + + private void studyDNN(long eventId, int featureIndex, int resIndex, OutBack outBack, boolean isStudy) throws Exception { + List sensoryNerves = nerveManager.getSensoryNerves(); + int size = sensoryNerves.size(); + Map map = new HashMap<>(); + if (resIndex > 0) { + map.put(resIndex + 1, 1f); + } + for (int i = 0; i < size; i++) { + float feature = 0; + if (i == featureIndex) { + feature = 1; + } + sensoryNerves.get(i).postMessage(eventId, feature, isStudy, map, outBack, true, null); + } + } + + + public WordTwoVectorModel start() throws Exception {//开始进行词向量训练 + List sentenceList = sentenceModel.getSentenceList(); + int size = sentenceList.size(); + System.out.println("词嵌入训练启动..."); + int allTimes = studyTimes * size; + int index = 0; + for (int k = 0; k < studyTimes; k++) { + for (int i = 0; i < size; i++) { + index++; + long start = System.currentTimeMillis(); + study(sentenceList.get(i)); + long end = System.currentTimeMillis() - start; + float r = (float) index / allTimes * 100; + String result = String.format("%.6f", r); + System.out.println("size:" + size + ",index:" + i + ",耗时:" + end + ",完成度:" + result + "%"); + } + } + WordTwoVectorModel wordTwoVectorModel = new WordTwoVectorModel(); + wordTwoVectorModel.setModelParameter(nerveManager.getModelParameter()); + wordTwoVectorModel.setWordList(wordList); + //词向量训练结束 + return wordTwoVectorModel; + } + + private void study(String[] word) throws Exception { + int[] indexArray = new int[word.length]; + for (int i = 0; i < word.length; i++) { + int index = getID(word[i]); + indexArray[i] = index; + } + for (int i = 0; i < indexArray.length; i++) { + int index = indexArray[i]; + for (int j = 0; j < indexArray.length; j++) { + if (i != j) { + int resIndex = indexArray[j]; + studyDNN(1, index, resIndex, null, true); + } + } + } + } + + public int getID(String word) { + int index = 0; + int size = wordList.size(); + for (int i = 0; i < size; i++) { + if (wordList.get(i).equals(word)) { + index = i; + break; + } + } + return index; + } +}