Add File
This commit is contained in:
@@ -0,0 +1,76 @@
|
|||||||
|
package org.dromara.easyai.rnnNerveEntity;
|
||||||
|
|
||||||
|
import org.dromara.easyai.matrixTools.Matrix;
|
||||||
|
import org.dromara.easyai.i.ActiveFunction;
|
||||||
|
import org.dromara.easyai.i.OutBack;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author lidapeng
|
||||||
|
* 输出神经元
|
||||||
|
* @date 11:25 上午 2019/12/21
|
||||||
|
*/
|
||||||
|
public class OutNerve extends Nerve {
|
||||||
|
private final boolean isShowLog;
|
||||||
|
private final boolean isSoftMax;
|
||||||
|
|
||||||
|
public OutNerve(int id, int upNub, int downNub, float studyPoint, boolean init,
|
||||||
|
ActiveFunction activeFunction, boolean isShowLog, int rzType, float lParam, boolean isSoftMax) throws Exception {
|
||||||
|
super(id, upNub, "OutNerve", downNub, studyPoint, init,
|
||||||
|
activeFunction, rzType, lParam, 0);
|
||||||
|
this.isShowLog = isShowLog;
|
||||||
|
this.isSoftMax = isSoftMax;
|
||||||
|
}
|
||||||
|
|
||||||
|
void getGBySoftMax(float g, long eventId) throws Exception {//接收softMax层回传梯度
|
||||||
|
gradient = g;
|
||||||
|
updatePower(eventId);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void input(long eventId, float parameter, boolean isStudy, Map<Integer, Float> E
|
||||||
|
, OutBack outBack, boolean isEmbedding, Matrix rnnMatrix) throws Exception {
|
||||||
|
boolean allReady = insertParameter(eventId, parameter, false);
|
||||||
|
if (allReady) {//参数齐了,开始计算 sigma - threshold
|
||||||
|
float sigma = calculation(eventId, false);
|
||||||
|
if (isSoftMax) {
|
||||||
|
if (!isStudy) {
|
||||||
|
destoryParameter(eventId);
|
||||||
|
}
|
||||||
|
sendMessage(eventId, sigma, isStudy, E, outBack, false, rnnMatrix);
|
||||||
|
} else {
|
||||||
|
float out = activeFunction.function(sigma);
|
||||||
|
if (isStudy) {//输出结果并进行BP调整权重及阈值
|
||||||
|
outNub = out;
|
||||||
|
if (E.containsKey(getId())) {
|
||||||
|
this.E = E.get(getId());
|
||||||
|
} else {
|
||||||
|
this.E = 0;
|
||||||
|
}
|
||||||
|
if (isShowLog) {
|
||||||
|
System.out.println("E==" + this.E + ",out==" + out + ",nerveId==" + getId());
|
||||||
|
}
|
||||||
|
gradient = outGradient();//当前梯度变化
|
||||||
|
//调整权重 修改阈值 并进行反向传播
|
||||||
|
updatePower(eventId);
|
||||||
|
} else {//获取最后输出
|
||||||
|
destoryParameter(eventId);
|
||||||
|
if (outBack != null) {
|
||||||
|
outBack.getBack(out, getId(), eventId);
|
||||||
|
} else {
|
||||||
|
throw new Exception("not find outBack");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private float outGradient() {//生成输出层神经元梯度变化
|
||||||
|
//上层神经元输入值 * 当前神经元梯度*学习率 =该上层输入的神经元权重变化
|
||||||
|
//当前梯度神经元梯度变化 *学习旅 * -1 = 当前神经元阈值变化
|
||||||
|
return activeFunction.functionG(outNub) * (E - outNub);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user