Add File
This commit is contained in:
299
src/main/java/org/dromara/easyai/unet/UNetDecoder.java
Normal file
299
src/main/java/org/dromara/easyai/unet/UNetDecoder.java
Normal file
@@ -0,0 +1,299 @@
|
||||
package org.dromara.easyai.unet;
|
||||
|
||||
import org.dromara.easyai.conv.ConvCount;
|
||||
import org.dromara.easyai.conv.DymStudy;
|
||||
import org.dromara.easyai.entity.ThreeChannelMatrix;
|
||||
import org.dromara.easyai.i.ActiveFunction;
|
||||
import org.dromara.easyai.i.OutBack;
|
||||
import org.dromara.easyai.matrixTools.Matrix;
|
||||
import org.dromara.easyai.matrixTools.MatrixOperation;
|
||||
import org.dromara.easyai.nerveEntity.ConvParameter;
|
||||
import org.dromara.easyai.nerveEntity.ConvSize;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
 * @author lidapeng
 * @time 2025/3/2 07:55
 * @des U-Net decoder
 */
|
||||
public class UNetDecoder extends ConvCount {
|
||||
private final ConvParameter convParameter = new ConvParameter();//内存中卷积层模型及临时数据
|
||||
private final MatrixOperation matrixOperation = new MatrixOperation();
|
||||
private final int kerSize;
|
||||
private final int deep;//当前深度
|
||||
private final float studyRate;//学习率
|
||||
private final int channelNo;//通道数
|
||||
private final boolean lastLay;//是否为最后一层
|
||||
private final ActiveFunction activeFunction;
|
||||
private UNetDecoder afterDecoder;//下一个解码器
|
||||
private UNetDecoder beforeDecoder;//上一个解码器
|
||||
private UNetEncoder encoder;//上一个编码器
|
||||
private UNetEncoder myUNetEncoder;//同级编码器
|
||||
private final ConvSize convSize = new ConvSize();
|
||||
private final Cutting cutting;//输出语义切割图像
|
||||
private final float oneConvStudyRate;//
|
||||
private final float gaMa;
|
||||
private final float gMaxTh;
|
||||
private final boolean autoStudyRate;//自动学习率
|
||||
|
||||
public UNetDecoder(int kerSize, int deep, int channelNo, ActiveFunction activeFunction
|
||||
, boolean lastLay, float studyRate, Cutting cutting, float oneConvStudyRate, float gaMa, float gMaxTh, boolean autoStudyRate) throws Exception {
|
||||
this.cutting = cutting;
|
||||
this.autoStudyRate = autoStudyRate;
|
||||
this.gMaxTh = gMaxTh;
|
||||
this.gaMa = gaMa;
|
||||
this.kerSize = kerSize;
|
||||
this.oneConvStudyRate = oneConvStudyRate;
|
||||
this.deep = deep;
|
||||
this.studyRate = studyRate;
|
||||
this.lastLay = lastLay;
|
||||
this.channelNo = channelNo;
|
||||
this.activeFunction = activeFunction;
|
||||
Random random = new Random();
|
||||
List<Matrix> nerveMatrixList = convParameter.getNerveMatrixList();
|
||||
List<Matrix> dymStudyRateList = convParameter.getDymStudyRateList();
|
||||
List<Matrix> upNeverMatrixList = convParameter.getUpNerveMatrixList();//上卷积采样权重
|
||||
List<Matrix> upDYmStudyRateList = convParameter.getUpDymStudyRateList();
|
||||
List<ConvSize> convSizeList = convParameter.getConvSizeList();
|
||||
for (int i = 0; i < channelNo; i++) {
|
||||
int convSize = kerSize * kerSize;
|
||||
upDYmStudyRateList.add(new Matrix(1, convSize));
|
||||
upNeverMatrixList.add(initUpNervePowerMatrix(random));
|
||||
initNervePowerMatrix(random, nerveMatrixList, dymStudyRateList);
|
||||
convSizeList.add(new ConvSize());
|
||||
}
|
||||
if (lastLay) {
|
||||
List<Float> oneConvPower = new ArrayList<>();
|
||||
List<Float> oneDymStudyRate = new ArrayList<>();
|
||||
for (int i = 0; i < channelNo; i++) {
|
||||
oneConvPower.add(random.nextFloat() / channelNo);
|
||||
oneDymStudyRate.add(0f);
|
||||
}
|
||||
convParameter.setUpOneDymStudyRateList(oneDymStudyRate);
|
||||
convParameter.setUpOneConvPower(oneConvPower);
|
||||
}
|
||||
}
|
||||
|
||||
public ConvParameter getConvParameter() {
|
||||
return convParameter;
|
||||
}
|
||||
|
||||
private ThreeChannelMatrix fillColor(ThreeChannelMatrix picture, int heightSize, int widthSize) throws Exception {
|
||||
int myFaceHeight = picture.getX();
|
||||
int sub = myFaceHeight - heightSize;
|
||||
int fillHeight = sub / 2;//高度差
|
||||
if (fillHeight == 0) {
|
||||
fillHeight = 1;
|
||||
}
|
||||
ThreeChannelMatrix fillMatrix = null;
|
||||
if (sub > 0) {//剪切
|
||||
fillMatrix = picture.cutChannel(fillHeight, 0, heightSize, widthSize);
|
||||
} else if (sub < 0) {//补0
|
||||
fillMatrix = getFaceMatrix(heightSize, widthSize);
|
||||
fillMatrix.fill(Math.abs(fillHeight), 0, picture);
|
||||
}
|
||||
return fillMatrix;
|
||||
}
|
||||
|
||||
private ThreeChannelMatrix getFaceMatrix(int height, int width) {
|
||||
ThreeChannelMatrix threeChannelMatrix = new ThreeChannelMatrix();
|
||||
Matrix matrixR = new Matrix(height, width);
|
||||
Matrix matrixG = new Matrix(height, width);
|
||||
Matrix matrixB = new Matrix(height, width);
|
||||
Matrix matrixH = new Matrix(height, width);
|
||||
threeChannelMatrix.setX(height);
|
||||
threeChannelMatrix.setY(width);
|
||||
threeChannelMatrix.setMatrixR(matrixR);
|
||||
threeChannelMatrix.setMatrixG(matrixG);
|
||||
threeChannelMatrix.setMatrixB(matrixB);
|
||||
threeChannelMatrix.setH(matrixH);
|
||||
return threeChannelMatrix;
|
||||
}
|
||||
|
||||
private void addFeatures(List<Matrix> encoderFeatures, List<Matrix> myFeatures, boolean study) throws Exception {
|
||||
int size = encoderFeatures.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
addFeature(encoderFeatures.get(i), myFeatures.get(i), study);
|
||||
}
|
||||
}
|
||||
|
||||
private void addFeature(Matrix encoderFeature, Matrix myFeature, boolean study) throws Exception {//获取残差块
|
||||
if (study) {
|
||||
convSize.setXInput(encoderFeature.getX());
|
||||
convSize.setYInput(encoderFeature.getY());
|
||||
}
|
||||
int tx = encoderFeature.getX();
|
||||
int ty = encoderFeature.getY();
|
||||
int x = myFeature.getX();
|
||||
int y = myFeature.getY();
|
||||
for (int i = 0; i < x; i++) {
|
||||
for (int j = 0; j < y; j++) {
|
||||
float encoderValue = 0;
|
||||
if (i < tx && j < ty) {
|
||||
encoderValue = encoderFeature.getNumber(i, j);
|
||||
}
|
||||
float value = (myFeature.getNumber(i, j) + encoderValue) / 2;
|
||||
myFeature.setNub(i, j, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void toThreeChannelMatrix(List<Matrix> features, ThreeChannelMatrix featureE, boolean study, OutBack outBack
|
||||
, ThreeChannelMatrix backGround) throws Exception {
|
||||
int x = features.get(0).getX();
|
||||
int y = features.get(0).getY();
|
||||
List<Float> upOneConvPower = convParameter.getUpOneConvPower();
|
||||
Matrix feature = oneConv(features, upOneConvPower);
|
||||
if (study) {//训练
|
||||
ThreeChannelMatrix sfe = featureE.scale(true, y);//缩放
|
||||
ThreeChannelMatrix fe = fillColor(sfe, x, y);//补0
|
||||
if (fe == null) {
|
||||
fe = sfe;
|
||||
}
|
||||
Matrix he = fe.calculateAvgGrayscale();
|
||||
Matrix errorMatrix = matrixOperation.sub(he, feature);//总误差
|
||||
//先更新分矩阵误差
|
||||
List<Matrix> errorMatrixList = new ArrayList<>();
|
||||
for (int i = 0; i < channelNo; i++) {
|
||||
float power = upOneConvPower.get(i);
|
||||
Matrix error = matrixOperation.mathMulBySelf(errorMatrix, power);
|
||||
errorMatrixList.add(error);
|
||||
}
|
||||
DymStudy dymStudy = new DymStudy(gaMa, gMaxTh, autoStudyRate);
|
||||
backOneConv(errorMatrix, features, upOneConvPower, oneConvStudyRate, convParameter.getUpOneDymStudyRateList(), dymStudy);//更新1v1卷积核
|
||||
backLastError(errorMatrixList);
|
||||
//误差矩阵开始back
|
||||
} else {//输出
|
||||
int mx = backGround.getX();
|
||||
int my = backGround.getY();
|
||||
int startX = (mx - feature.getX()) / 2;
|
||||
int startY = (my - feature.getY()) / 2;
|
||||
Matrix myMatrix = new Matrix(mx, my);
|
||||
for (int i = startX; i < x; i++) {
|
||||
for (int j = startY; j < y; j++) {
|
||||
myMatrix.setNub(i, j, feature.getNumber(i - startX, j - startY));
|
||||
}
|
||||
}
|
||||
ThreeChannelMatrix threeChannelMatrix = new ThreeChannelMatrix();
|
||||
threeChannelMatrix.setX(x);
|
||||
threeChannelMatrix.setY(y);
|
||||
threeChannelMatrix.setMatrixR(myMatrix);
|
||||
threeChannelMatrix.setMatrixG(myMatrix);
|
||||
threeChannelMatrix.setMatrixB(myMatrix);
|
||||
if (cutting != null) {
|
||||
cutting.cut(backGround, threeChannelMatrix, outBack);
|
||||
} else {
|
||||
outBack.getBackThreeChannelMatrix(threeChannelMatrix);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void backLastError(List<Matrix> errorMatrixList) throws Exception {//最后一层的误差反向传播
|
||||
List<Matrix> errorList = backAllDownConv(convParameter, errorMatrixList, studyRate, activeFunction, channelNo, kerSize, gaMa, gMaxTh
|
||||
, autoStudyRate);
|
||||
sendEncoderError(errorList);//给同级解码器发送误差
|
||||
beforeDecoder.backErrorMatrix(errorList);
|
||||
}
|
||||
|
||||
private void sendEncoderError(List<Matrix> errors) throws Exception {//给同级解码器发送误差
|
||||
List<Matrix> encoderErrors = new ArrayList<>();
|
||||
for (Matrix error : errors) {
|
||||
Matrix encoderError = new Matrix(convSize.getXInput(), convSize.getYInput());
|
||||
int x = convSize.getXInput();
|
||||
int y = convSize.getYInput();
|
||||
int tx = error.getX();
|
||||
int ty = error.getY();
|
||||
for (int i = 0; i < x; i++) {
|
||||
for (int j = 0; j < y; j++) {
|
||||
float value = 0;
|
||||
if (i < tx && j < ty) {
|
||||
value = error.getNumber(i, j) / 2;
|
||||
}
|
||||
encoderError.setNub(i, j, value);
|
||||
}
|
||||
}
|
||||
encoderErrors.add(encoderError);
|
||||
}
|
||||
myUNetEncoder.setDecodeErrorMatrix(encoderErrors);
|
||||
}
|
||||
|
||||
protected void backErrorMatrix(List<Matrix> myErrorMatrixList) throws Exception {//接收解码器误差
|
||||
//退上池化,退上卷积 退下卷积 并返回编码器误差
|
||||
List<Matrix> errorList = backManyUpPooling(myErrorMatrixList);//退上池化
|
||||
List<Matrix> errorMatrixList = backManyUpConv(errorList, kerSize, convParameter, studyRate, activeFunction, gaMa, gMaxTh, autoStudyRate);//退上卷积
|
||||
List<Matrix> backList = backAllDownConv(convParameter, errorMatrixList, studyRate, activeFunction, channelNo, kerSize, gaMa, gMaxTh
|
||||
, autoStudyRate);//退下卷积
|
||||
if (myUNetEncoder != null) {
|
||||
sendEncoderError(backList);//给同级编码器发送误差
|
||||
}
|
||||
if (beforeDecoder != null) {
|
||||
beforeDecoder.backErrorMatrix(backList);
|
||||
} else {//给上一个编码器发送误差
|
||||
encoder.backError(backList);
|
||||
}
|
||||
}
|
||||
|
||||
protected void sendFeature(long eventID, OutBack outBack, ThreeChannelMatrix featureE,
|
||||
List<Matrix> myFeatures, boolean study, ThreeChannelMatrix backGround) throws Exception {
|
||||
if (deep > 1) {
|
||||
List<Matrix> encoderMatrixList = myUNetEncoder.getAfterConvMatrix(eventID);//编码器特征
|
||||
addFeatures(encoderMatrixList, myFeatures, study);
|
||||
}
|
||||
List<Matrix> upConvMatrixList = upConvAndPooling(myFeatures, convParameter, channelNo, activeFunction, kerSize, !lastLay);
|
||||
if (lastLay) {//最后一层解码器
|
||||
toThreeChannelMatrix(upConvMatrixList, featureE, study, outBack, backGround);
|
||||
} else {
|
||||
afterDecoder.sendFeature(eventID, outBack, featureE, upConvMatrixList, study, backGround);
|
||||
}
|
||||
}
|
||||
|
||||
private Matrix initUpNervePowerMatrix(Random random) throws Exception {
|
||||
int convSize = kerSize * kerSize;
|
||||
Matrix nervePowerMatrix = new Matrix(1, convSize);
|
||||
for (int j = 0; j < convSize; j++) {
|
||||
float power = random.nextFloat() / kerSize;
|
||||
nervePowerMatrix.setNub(0, j, power);
|
||||
}
|
||||
return nervePowerMatrix;
|
||||
}
|
||||
|
||||
private void initNervePowerMatrix(Random random, List<Matrix> nervePowerMatrixList, List<Matrix> dymStudyRateList) throws Exception {
|
||||
int convSize = kerSize * kerSize;
|
||||
Matrix nervePowerMatrix = new Matrix(convSize, 1);
|
||||
for (int i = 0; i < convSize; i++) {
|
||||
float power = random.nextFloat() / kerSize;
|
||||
nervePowerMatrix.setNub(i, 0, power);
|
||||
}
|
||||
dymStudyRateList.add(new Matrix(convSize, 1));
|
||||
nervePowerMatrixList.add(nervePowerMatrix);
|
||||
}
|
||||
|
||||
public UNetDecoder getAfterDecoder() {
|
||||
return afterDecoder;
|
||||
}
|
||||
|
||||
public void setAfterDecoder(UNetDecoder afterDecoder) {
|
||||
this.afterDecoder = afterDecoder;
|
||||
}
|
||||
|
||||
public UNetDecoder getBeforeDecoder() {
|
||||
return beforeDecoder;
|
||||
}
|
||||
|
||||
public void setBeforeDecoder(UNetDecoder beforeDecoder) {
|
||||
this.beforeDecoder = beforeDecoder;
|
||||
}
|
||||
|
||||
public UNetEncoder getEncoder() {
|
||||
return encoder;
|
||||
}
|
||||
|
||||
public void setEncoder(UNetEncoder encoder) {
|
||||
this.encoder = encoder;
|
||||
}
|
||||
|
||||
public void setMyUNetEncoder(UNetEncoder myUNetEncoder) {
|
||||
this.myUNetEncoder = myUNetEncoder;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user