diff --git a/src/main/java/org/dromara/easyai/resnet/ConvLay.java b/src/main/java/org/dromara/easyai/resnet/ConvLay.java new file mode 100644 index 0000000..916d501 --- /dev/null +++ b/src/main/java/org/dromara/easyai/resnet/ConvLay.java @@ -0,0 +1,78 @@ +package org.dromara.easyai.resnet; + +import org.dromara.easyai.matrixTools.Matrix; +import org.dromara.easyai.matrixTools.MatrixNorm; +import org.dromara.easyai.resnet.entity.BackParameter; +import org.dromara.easyai.resnet.entity.NormModel; +import org.dromara.easyai.resnet.entity.ResConvModel; + +import java.util.ArrayList; +import java.util.List; + +/** + * @author lidapeng + * @time 2025/4/11 17:18 + * @des 单个卷积层 + */ +public class ConvLay { + private List convPower;//第一层卷积权重 需要作为模型取出 + private List dymStudyRateList;//动态学习率 + private List matrixNormList;//归一化层// 需要作为模型取出 + private final BackParameter backParameter = new BackParameter(); + + public List getDymStudyRateList() { + return dymStudyRateList; + } + + public void setDymStudyRateList(List dymStudyRateList) { + this.dymStudyRateList = dymStudyRateList; + } + + public ResConvModel getModel() { + ResConvModel resConvModel = new ResConvModel(); + List normModelList = new ArrayList<>(); + List convPowerList = new ArrayList<>(); + resConvModel.setNormModelList(normModelList); + resConvModel.setConvPowerModelList(convPowerList); + for (MatrixNorm matrixNorm : matrixNormList) { + normModelList.add(matrixNorm.getModel()); + } + for (Matrix matrix : convPower) { + convPowerList.add(matrix.getMatrixModel()); + } + return resConvModel; + } + + public void insertModel(ResConvModel resConvModel) { + List normModelList = resConvModel.getNormModelList(); + List convPowerList = resConvModel.getConvPowerModelList(); + int normSize = matrixNormList.size(); + for (int i = 0; i < normSize; i++) { + matrixNormList.get(i).insertModel(normModelList.get(i)); + } + int nerveSize = convPower.size(); + for (int i = 0; i < nerveSize; i++) { + convPower.get(i).insertMatrixModel(convPowerList.get(i)); + } + } + + public BackParameter getBackParameter() { + return backParameter; + } + + public List getConvPower() { + return convPower; + } + + public void setConvPower(List convPower) { + this.convPower = convPower; + } + + public List getMatrixNormList() { + return matrixNormList; + } + + public void setMatrixNormList(List matrixNormList) { + this.matrixNormList = matrixNormList; + } +}