Add File
This commit is contained in:
@@ -0,0 +1,76 @@
|
|||||||
|
package org.dromara.easyai.resnet;
|
||||||
|
|
||||||
|
import org.dromara.easyai.i.CustomEncoding;
|
||||||
|
import org.dromara.easyai.matrixTools.Matrix;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author lidapeng
|
||||||
|
* @time 2025/4/11 11:17
|
||||||
|
*/
|
||||||
|
public class ResNetConnectionLine implements CustomEncoding {
|
||||||
|
private ResBlock lastBlock;//最后一层残差块
|
||||||
|
private int lastSize;//最后一层的特征大小
|
||||||
|
private int allTimes;//隐层神经数量
|
||||||
|
private int nerveSize;//输入神经元数量
|
||||||
|
private int number = 0;
|
||||||
|
private final List<Float> errorValues = new ArrayList<>();
|
||||||
|
|
||||||
|
public void setLastBlock(ResBlock lastBlock, int lastSize, int allTimes, int nerveSize) {
|
||||||
|
this.lastBlock = lastBlock;
|
||||||
|
this.lastSize = lastSize;
|
||||||
|
this.allTimes = allTimes;
|
||||||
|
this.nerveSize = nerveSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addError(Map<Integer, Float> wg) throws Exception {
|
||||||
|
if (wg.size() != nerveSize) {
|
||||||
|
throw new Exception("线性层回传误差数量与预设值不相等");
|
||||||
|
}
|
||||||
|
for (int i = 1; i <= nerveSize; i++) {
|
||||||
|
float error = wg.get(i);
|
||||||
|
if (number == 1) {
|
||||||
|
errorValues.add(error);
|
||||||
|
} else {
|
||||||
|
float value = error + errorValues.get(i - 1);
|
||||||
|
errorValues.set(i - 1, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void fill(Matrix feature, float value) throws Exception {
|
||||||
|
int x = feature.getX();
|
||||||
|
int y = feature.getY();
|
||||||
|
float myValue = value / (x * y);
|
||||||
|
for (int i = 0; i < x; i++) {
|
||||||
|
for (int j = 0; j < y; j++) {
|
||||||
|
feature.setNub(i, j, myValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void toError() throws Exception {
|
||||||
|
List<Matrix> errorMatrix = new ArrayList<>();
|
||||||
|
for (Float errorValue : errorValues) {
|
||||||
|
Matrix feature = new Matrix(lastSize, lastSize);
|
||||||
|
float error = errorValue;
|
||||||
|
fill(feature, error);
|
||||||
|
errorMatrix.add(feature);
|
||||||
|
}
|
||||||
|
errorValues.clear();
|
||||||
|
lastBlock.backError(errorMatrix);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void backError(Map<Integer, Float> wg, int id) throws Exception {
|
||||||
|
number++;
|
||||||
|
addError(wg);
|
||||||
|
if (number == allTimes) {
|
||||||
|
number = 0;
|
||||||
|
toError();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user