Add File
This commit is contained in:
203
src/main/java/org/dromara/easyai/unet/UNetManager.java
Normal file
203
src/main/java/org/dromara/easyai/unet/UNetManager.java
Normal file
@@ -0,0 +1,203 @@
|
||||
package org.dromara.easyai.unet;
|
||||
|
||||
|
||||
import org.dromara.easyai.config.UNetConfig;
|
||||
import org.dromara.easyai.conv.ConvCount;
|
||||
import org.dromara.easyai.function.ReLu;
|
||||
import org.dromara.easyai.function.Tanh;
|
||||
import org.dromara.easyai.matrixTools.Matrix;
|
||||
import org.dromara.easyai.nerveEntity.ConvParameter;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author lidapeng
|
||||
* @time 2025/2/25 20:19
|
||||
* @des UNet网络管理器
|
||||
*/
|
||||
public class UNetManager extends ConvCount {
|
||||
private final List<UNetEncoder> encoderList = new ArrayList<>();
|
||||
private final List<UNetDecoder> decoderList = new ArrayList<>();
|
||||
private final int kernLen;
|
||||
private final int channelNo;
|
||||
private final int deep;
|
||||
private final float studyRate;
|
||||
private final float oneStudyRate;//1v1卷积权重学习率
|
||||
private UNetInput input;//输入类
|
||||
private final float gaMa;
|
||||
private final float gMaxTh;
|
||||
private final boolean auTo;
|
||||
|
||||
public UNetInput getInput() {
|
||||
return input;
|
||||
}
|
||||
|
||||
public UNetManager(UNetConfig uNetConfig) throws Exception {
|
||||
int xSize = uNetConfig.getXSize();
|
||||
int ySize = uNetConfig.getYSize();
|
||||
gaMa = uNetConfig.getGaMa();
|
||||
gMaxTh = uNetConfig.getGMaxTh();
|
||||
auTo = uNetConfig.isAuto();
|
||||
int minFeatureValue = uNetConfig.getMinFeatureValue();
|
||||
this.kernLen = uNetConfig.getKerSize();
|
||||
this.channelNo = uNetConfig.getChannelNo();
|
||||
this.studyRate = uNetConfig.getStudyRate();
|
||||
this.oneStudyRate = uNetConfig.getOneStudyRate();
|
||||
this.deep = getConvMyDep(xSize, ySize, kernLen, minFeatureValue);//编码器深度深度
|
||||
if (deep > 1) {
|
||||
initEncoder(xSize, ySize);//初始化编码器
|
||||
initDecoder(uNetConfig.isCutting(), uNetConfig.getCutTh());
|
||||
connectionCoder();
|
||||
} else {
|
||||
throw new Exception("minFeatureValue 设置的值太大了");
|
||||
}
|
||||
}
|
||||
|
||||
private float[] getFValue(Float[] values) {
|
||||
float[] fValue = new float[values.length];
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
fValue[i] = values[i];
|
||||
}
|
||||
return fValue;
|
||||
}
|
||||
|
||||
private Float[] getValue(float[] values) {
|
||||
Float[] result = new Float[values.length];
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
result[i] = values[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public void insertModel(UNetModel uNetModel) throws Exception {
|
||||
List<ConvModel> encoderModel = uNetModel.getEncoderModels();
|
||||
List<ConvModel> decoderModel = uNetModel.getDecoderModels();
|
||||
if (encoderModel.size() != deep) {
|
||||
throw new Exception("模型深度不匹配");
|
||||
}
|
||||
for (int i = 0; i < deep; i++) {
|
||||
ConvParameter convParameter = encoderList.get(i).getConvParameter();
|
||||
List<Matrix> matrixList = convParameter.getNerveMatrixList();
|
||||
ConvModel convModel = encoderModel.get(i);
|
||||
List<Float[]> downPowers = convModel.getDownNervePower();
|
||||
List<List<Float>> oneNervePower = convModel.getOneNervePowerList();
|
||||
convParameter.setOneConvPower(oneNervePower);
|
||||
for (int j = 0; j < matrixList.size(); j++) {
|
||||
Matrix matrix = matrixList.get(j);
|
||||
float[] power = getFValue(downPowers.get(j));
|
||||
matrix.setCudaMatrix(power, matrix.getX(), matrix.getY());
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < deep + 1; i++) {
|
||||
ConvParameter convParameter = decoderList.get(i).getConvParameter();
|
||||
List<Matrix> matrixList = convParameter.getNerveMatrixList();
|
||||
ConvModel convModel = decoderModel.get(i);
|
||||
List<Float[]> downPowers = convModel.getDownNervePower();
|
||||
List<Float[]> upNervePowerModel = convModel.getUpNervePower();
|
||||
convParameter.setUpOneConvPower(convModel.getOneNervePower());
|
||||
List<Matrix> upNervePowers = convParameter.getUpNerveMatrixList();
|
||||
for (int j = 0; j < upNervePowerModel.size(); j++) {
|
||||
float[] upPower = getFValue(upNervePowerModel.get(j));
|
||||
Matrix upNervePower = upNervePowers.get(j);
|
||||
upNervePower.setCudaMatrix(upPower, upNervePower.getX(), upNervePower.getY());
|
||||
}
|
||||
for (int j = 0; j < matrixList.size(); j++) {
|
||||
Matrix matrix = matrixList.get(j);
|
||||
float[] power = getFValue(downPowers.get(j));
|
||||
matrix.setCudaMatrix(power, matrix.getX(), matrix.getY());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public UNetModel getModel() {
|
||||
UNetModel unetModel = new UNetModel();
|
||||
List<ConvModel> encoderModel = new ArrayList<>();
|
||||
List<ConvModel> decoderModel = new ArrayList<>();
|
||||
unetModel.setEncoderModels(encoderModel);
|
||||
unetModel.setDecoderModels(decoderModel);
|
||||
for (int i = 0; i < deep; i++) {//遍历每一层
|
||||
ConvModel convModel = new ConvModel();
|
||||
encoderModel.add(convModel);
|
||||
ConvParameter convParameter = encoderList.get(i).getConvParameter();
|
||||
List<Float[]> downNervePower = new ArrayList<>();
|
||||
convModel.setDownNervePower(downNervePower);
|
||||
List<List<Float>> onePowers = convParameter.getOneConvPower();
|
||||
if (onePowers != null && !onePowers.isEmpty()) {
|
||||
convModel.setOneNervePowerList(onePowers);
|
||||
}
|
||||
List<Matrix> downNerveMatrix = convParameter.getNerveMatrixList();//下采样卷积权重
|
||||
for (Matrix nerveMatrix : downNerveMatrix) {
|
||||
Float[] downPower = getValue(nerveMatrix.getCudaMatrix());
|
||||
downNervePower.add(downPower);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < deep + 1; i++) {
|
||||
ConvModel convModel = new ConvModel();
|
||||
decoderModel.add(convModel);
|
||||
ConvParameter convParameter = decoderList.get(i).getConvParameter();
|
||||
convModel.setOneNervePower(convParameter.getUpOneConvPower());
|
||||
List<Float[]> downNervePower = new ArrayList<>();
|
||||
convModel.setDownNervePower(downNervePower);
|
||||
List<Matrix> upNerveMatrix = convParameter.getUpNerveMatrixList();
|
||||
List<Float[]> upNervePower = new ArrayList<>();
|
||||
for (Matrix upMatrix : upNerveMatrix) {
|
||||
upNervePower.add(getValue(upMatrix.getCudaMatrix()));
|
||||
}
|
||||
convModel.setUpNervePower(upNervePower);
|
||||
List<Matrix> downNerveMatrix = convParameter.getNerveMatrixList();
|
||||
for (Matrix nerveMatrix : downNerveMatrix) {
|
||||
Float[] downPower = getValue(nerveMatrix.getCudaMatrix());
|
||||
downNervePower.add(downPower);
|
||||
}
|
||||
}
|
||||
return unetModel;
|
||||
}
|
||||
|
||||
private void connectionCoder() {//链接编解码器
|
||||
UNetEncoder lastUNetEncoder = encoderList.get(deep - 1);//最后一层编码器
|
||||
UNetDecoder firstUNetDecoder = decoderList.get(0);//第一层解码器
|
||||
lastUNetEncoder.setDecoder(firstUNetDecoder);
|
||||
firstUNetDecoder.setEncoder(lastUNetEncoder);
|
||||
for (int i = 0; i < deep; i++) {
|
||||
UNetEncoder uNetEncoder = encoderList.get(i);
|
||||
UNetDecoder uNetDecoder = decoderList.get(deep - i);
|
||||
uNetDecoder.setMyUNetEncoder(uNetEncoder);//绑定统计编码器
|
||||
}
|
||||
}
|
||||
|
||||
private void initDecoder(boolean cutting, float cutTh) throws Exception {
|
||||
Cutting myCut = null;
|
||||
if (cutting) {
|
||||
myCut = new Cutting(cutTh);
|
||||
}
|
||||
for (int i = 0; i < deep + 1; i++) {
|
||||
UNetDecoder uNetDecoder = new UNetDecoder(kernLen, i + 1, channelNo, new Tanh(),
|
||||
i == deep, studyRate, myCut, oneStudyRate, gaMa, gMaxTh, auTo);
|
||||
decoderList.add(uNetDecoder);
|
||||
}
|
||||
for (int i = 0; i < deep; i++) {
|
||||
UNetDecoder uNetDecoder = decoderList.get(i);
|
||||
UNetDecoder nextUNetDecoder = decoderList.get(i + 1);
|
||||
uNetDecoder.setAfterDecoder(nextUNetDecoder);
|
||||
nextUNetDecoder.setBeforeDecoder(uNetDecoder);
|
||||
}
|
||||
}
|
||||
|
||||
private void initEncoder(int xSize, int ySize) throws Exception {
|
||||
for (int i = 0; i < deep; i++) {
|
||||
UNetEncoder uNetEncoder = new UNetEncoder(kernLen, channelNo, i + 1, new ReLu(), studyRate
|
||||
, xSize, ySize, oneStudyRate, gaMa, gMaxTh, auTo);
|
||||
if (i == 0) {
|
||||
input = new UNetInput(uNetEncoder);
|
||||
}
|
||||
encoderList.add(uNetEncoder);
|
||||
}
|
||||
for (int i = 0; i < deep - 1; i++) {
|
||||
UNetEncoder uNetEncoder = encoderList.get(i);
|
||||
UNetEncoder nextUNetEncoder = encoderList.get(i + 1);
|
||||
uNetEncoder.setAfterEncoder(nextUNetEncoder);
|
||||
nextUNetEncoder.setBeforeEncoder(uNetEncoder);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user