Add File
This commit is contained in:
102
src/main/java/org/dromara/easyai/transFormer/LineBlock.java
Normal file
102
src/main/java/org/dromara/easyai/transFormer/LineBlock.java
Normal file
@@ -0,0 +1,102 @@
|
||||
package org.dromara.easyai.transFormer;
|
||||
|
||||
import org.dromara.easyai.function.ReLu;
|
||||
import org.dromara.easyai.function.Tanh;
|
||||
import org.dromara.easyai.i.OutBack;
|
||||
import org.dromara.easyai.matrixTools.Matrix;
|
||||
import org.dromara.easyai.matrixTools.MatrixOperation;
|
||||
import org.dromara.easyai.transFormer.model.LineBlockModel;
|
||||
import org.dromara.easyai.transFormer.nerve.HiddenNerve;
|
||||
import org.dromara.easyai.transFormer.nerve.Nerve;
|
||||
import org.dromara.easyai.transFormer.nerve.OutNerve;
|
||||
import org.dromara.easyai.transFormer.nerve.SoftMax;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class LineBlock {//线性层模块
|
||||
private final List<HiddenNerve> hiddenNerveList = new ArrayList<>();
|
||||
private final List<OutNerve> outNerveList = new ArrayList<>();//输出层
|
||||
private final CodecBlock lastCodecBlock;//最后一层解码块
|
||||
private Matrix allError;
|
||||
private final int featureDimension;
|
||||
private int backNumber = 0;//误差返回次数
|
||||
private final MatrixOperation matrixOperation;
|
||||
|
||||
public LineBlockModel getModel() throws Exception {
|
||||
LineBlockModel lineBlockModel = new LineBlockModel();
|
||||
List<float[][]> hiddenNerveModel = new ArrayList<>();
|
||||
List<float[][]> outNerveModel = new ArrayList<>();
|
||||
for (HiddenNerve hiddenNerve : hiddenNerveList) {
|
||||
hiddenNerveModel.add(hiddenNerve.getModel());
|
||||
}
|
||||
for (OutNerve outNerve : outNerveList) {
|
||||
outNerveModel.add(outNerve.getModel());
|
||||
}
|
||||
lineBlockModel.setHiddenNervesModel(hiddenNerveModel);
|
||||
lineBlockModel.setOutNervesModel(outNerveModel);
|
||||
return lineBlockModel;
|
||||
}
|
||||
|
||||
public void insertModel(LineBlockModel lineBlockModel) throws Exception {
|
||||
List<float[][]> hiddenNerveModel = lineBlockModel.getHiddenNervesModel();
|
||||
List<float[][]> outNerveModel = lineBlockModel.getOutNervesModel();
|
||||
for (int i = 0; i < hiddenNerveList.size(); i++) {
|
||||
hiddenNerveList.get(i).insertModel(hiddenNerveModel.get(i));
|
||||
}
|
||||
for (int i = 0; i < outNerveList.size(); i++) {
|
||||
outNerveList.get(i).insertModel(outNerveModel.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
public LineBlock(int typeNumber, int featureDimension, float studyPoint, CodecBlock lastCodecBlock,
|
||||
boolean showLog, int regularModel, float regular, int coreNumber, float timePunValue) throws Exception {
|
||||
this.featureDimension = featureDimension;
|
||||
this.lastCodecBlock = lastCodecBlock;
|
||||
matrixOperation = new MatrixOperation(coreNumber);
|
||||
SoftMax softMax = new SoftMax(outNerveList, showLog, typeNumber, typeNumber, typeNumber, timePunValue);
|
||||
//隐层
|
||||
List<Nerve> hiddenNerves = new ArrayList<>();
|
||||
for (int i = 0; i < featureDimension; i++) {
|
||||
HiddenNerve hiddenNerve = new HiddenNerve(i + 1, 1, studyPoint, new ReLu(), featureDimension,
|
||||
typeNumber, this, regularModel, regular, coreNumber);
|
||||
hiddenNerves.add(hiddenNerve);
|
||||
hiddenNerveList.add(hiddenNerve);
|
||||
}
|
||||
//输出层
|
||||
List<Nerve> outNerves = new ArrayList<>();
|
||||
for (int i = 0; i < typeNumber; i++) {
|
||||
OutNerve outNerve = new OutNerve(i + 1, studyPoint, featureDimension, featureDimension, typeNumber, softMax
|
||||
, regularModel, regular, coreNumber);
|
||||
outNerve.connectFather(hiddenNerves);
|
||||
outNerves.add(outNerve);
|
||||
outNerveList.add(outNerve);
|
||||
}
|
||||
for (Nerve nerve : hiddenNerves) {
|
||||
nerve.connect(outNerves);
|
||||
}
|
||||
}
|
||||
|
||||
public void sendParameter(long eventID, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E, boolean outAllPro) throws Exception {
|
||||
for (HiddenNerve hiddenNerve : hiddenNerveList) {
|
||||
hiddenNerve.postMessage(eventID, feature, isStudy, outBack, E, outAllPro);
|
||||
}
|
||||
}
|
||||
|
||||
public void backError(long eventID, Matrix errorMatrix) throws Exception {//从线性层返回的误差
|
||||
backNumber++;
|
||||
if (allError == null) {
|
||||
allError = errorMatrix;
|
||||
} else {
|
||||
allError = matrixOperation.add(errorMatrix, allError);
|
||||
}
|
||||
if (backNumber == featureDimension) {
|
||||
backNumber = 0;
|
||||
Matrix error = allError.getSonOfMatrix(0, 0, allError.getX(), allError.getY() - 1);
|
||||
allError = null;
|
||||
//将误差矩阵回传
|
||||
lastCodecBlock.backError(eventID, error);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user