This commit is contained in:
2025-08-27 19:58:38 +08:00
parent 9fb3764f6a
commit 32ae8a340a

View File

@@ -0,0 +1,497 @@
/*
* Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.agentsflex.core.react;
import com.agentsflex.core.llm.ChatContext;
import com.agentsflex.core.llm.ChatOptions;
import com.agentsflex.core.llm.Llm;
import com.agentsflex.core.llm.StreamResponseListener;
import com.agentsflex.core.llm.functions.Function;
import com.agentsflex.core.llm.functions.Parameter;
import com.agentsflex.core.llm.response.AiMessageResponse;
import com.agentsflex.core.message.AiMessage;
import com.agentsflex.core.message.Message;
import com.agentsflex.core.prompt.HistoriesPrompt;
import com.agentsflex.core.util.StringUtil;
import com.alibaba.fastjson.JSON;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* ReActAgent 是一个通用的 ReAct 模式 Agent支持 Reasoning + Action 的交互方式。
*/
public class ReActAgent {
private static final Logger log = LoggerFactory.getLogger(ReActAgent.class);
private static final String DEFAULT_PROMPT_TEMPLATE =
"你是一个 ReAct Agent结合 Reasoning推理和 Action行动来解决问题。\n" +
"但在处理用户问题时,请首先判断:\n" +
"1. 如果问题可以通过你的常识或已有知识直接回答 → 请忽略 ReAct 框架,直接输出自然语言回答。\n" +
"2. 如果问题需要调用特定工具才能解决(如查询、计算、获取外部信息等)→ 请严格按照 ReAct 格式响应。\n" +
"\n" +
"如果你选择使用 ReAct 模式,请遵循以下格式:\n" +
"Thought: 描述你对当前问题的理解,包括已知信息和缺失信息,说明你下一步将采取什么行动及其原因。\n" +
"Action: 从下方列出的工具中选择一个合适的工具,仅输出工具名称,不得虚构。\n" +
"Action Input: 使用标准 JSON 格式提供该工具所需的参数,确保字段名与工具描述一致。\n" +
"\n" +
"在 ReAct 模式下,如果你已获得足够信息可以直接回答用户,请输出:\n" +
"Final Answer: [你的回答]\n" +
"\n" +
"注意事项:\n" +
"1. 每次只能选择一个工具并执行一个动作。\n" +
"2. 在未收到工具执行结果前,不要自行假设其输出。\n" +
"3. 不得编造工具或参数,所有工具均列于下方。\n" +
"4. 输出顺序必须为Thought → Action → Action Input。\n" +
"\n" +
"### 可用工具列表:\n" +
"{tools}\n" +
"\n" +
"### 用户问题如下:\n" +
"{user_input}";
// 默认最大迭代次数
private static final int DEFAULT_MAX_ITERATIONS = 5;
private final Llm llm;
private final List<Function> functions;
private final String userQuery;
private boolean streamable = false;
private int maxIterations = DEFAULT_MAX_ITERATIONS;
private String promptTemplate = DEFAULT_PROMPT_TEMPLATE;
private ReActStepParser reActStepParser = ReActStepParser.DEFAULT; // 默认解析器
private final HistoriesPrompt historiesPrompt;
private ChatOptions chatOptions = ChatOptions.DEFAULT;
private ReActMessageBuilder messageBuilder = new ReActMessageBuilder();
// 是否继续执行当 JSON 解析出错时
private boolean continueOnActionJsonParseError = true;
// 是否继续执行当 Function 调用出错时
private boolean continueOnActionInvokeError = true;
// 监听器集合
private final List<ReActAgentListener> listeners = new ArrayList<>();
private int iterationCount = 0;
public ReActAgent(Llm llm, List<Function> functions, String userQuery) {
this.llm = llm;
this.functions = functions;
this.userQuery = userQuery;
this.historiesPrompt = new HistoriesPrompt();
}
public ReActAgent(Llm llm, List<Function> functions, String userQuery, HistoriesPrompt historiesPrompt) {
this.llm = llm;
this.functions = functions;
this.userQuery = userQuery;
this.historiesPrompt = historiesPrompt;
}
/**
* 注册监听器
*/
public void addListener(ReActAgentListener listener) {
listeners.add(listener);
}
/**
* 移除监听器
*/
public void removeListener(ReActAgentListener listener) {
listeners.remove(listener);
}
public Llm getLlm() {
return llm;
}
public List<Function> getFunctions() {
return functions;
}
public String getUserQuery() {
return userQuery;
}
public int getMaxIterations() {
return maxIterations;
}
public void setMaxIterations(int maxIterations) {
this.maxIterations = maxIterations;
}
public String getPromptTemplate() {
return promptTemplate;
}
public void setPromptTemplate(String promptTemplate) {
this.promptTemplate = promptTemplate;
}
public ReActStepParser getReActStepParser() {
return reActStepParser;
}
public void setReActStepParser(ReActStepParser reActStepParser) {
this.reActStepParser = reActStepParser;
}
public List<ReActAgentListener> getListeners() {
return listeners;
}
public boolean isStreamable() {
return streamable;
}
public void setStreamable(boolean streamable) {
this.streamable = streamable;
}
public HistoriesPrompt getHistoriesPrompt() {
return historiesPrompt;
}
public int getIterationCount() {
return iterationCount;
}
public void setIterationCount(int iterationCount) {
this.iterationCount = iterationCount;
}
public boolean isContinueOnActionJsonParseError() {
return continueOnActionJsonParseError;
}
public void setContinueOnActionJsonParseError(boolean continueOnActionJsonParseError) {
this.continueOnActionJsonParseError = continueOnActionJsonParseError;
}
public boolean isContinueOnActionInvokeError() {
return continueOnActionInvokeError;
}
public void setContinueOnActionInvokeError(boolean continueOnActionInvokeError) {
this.continueOnActionInvokeError = continueOnActionInvokeError;
}
public ReActMessageBuilder getMessageBuilder() {
return messageBuilder;
}
public void setMessageBuilder(ReActMessageBuilder messageBuilder) {
this.messageBuilder = messageBuilder;
}
public ChatOptions getChatOptions() {
return chatOptions;
}
public void setChatOptions(ChatOptions chatOptions) {
this.chatOptions = chatOptions;
}
/**
* 运行 ReAct Agent 流程
*/
public void run() {
try {
String toolsDescription = buildToolsDescription(functions);
String prompt = promptTemplate
.replace("{tools}", toolsDescription)
.replace("{user_input}", userQuery);
Message message = messageBuilder.buildStartMessage(prompt, functions, userQuery);
historiesPrompt.addMessage(message);
if (this.isStreamable()) {
startNextReActStepStream();
} else {
startNextReactStepNormal();
}
} catch (Exception e) {
log.error("运行 ReAct Agent 出错:" + e);
notifyOnError(e);
}
}
private void startNextReactStepNormal() {
for (int i = 0; i < maxIterations; i++) {
AiMessageResponse response = llm.chat(historiesPrompt, chatOptions);
String content = response.getMessage().getContent();
historiesPrompt.addMessage(new AiMessage(content));
notifyOnChatResponse(response);
if (isReActAction(content)) {
if (!processReActSteps(content)) {
break;
}
} else if (isFinalAnswer(content)) {
String flag = reActStepParser.getFinalAnswerFlag();
String answer = content.substring(content.indexOf(flag) + flag.length());
notifyOnFinalAnswer(answer);
break;
} else {
// 不是Action
notifyOnNonActionResponse(response);
break;
}
}
}
private void startNextReActStepStream() {
if (iterationCount >= maxIterations) {
notifyOnMaxIterationsReached();
return;
}
iterationCount++;
llm.chatStream(historiesPrompt, new StreamResponseListener() {
@Override
public void onMessage(ChatContext context, AiMessageResponse response) {
notifyOnChatResponseStream(context, response);
}
@Override
public void onStop(ChatContext context) {
AiMessage lastAiMessage = context.getLastAiMessage();
if (lastAiMessage == null) {
notifyOnError(new RuntimeException("没有收到任何回复"));
return;
}
String content = lastAiMessage.getFullContent();
if (StringUtil.noText(content)) {
notifyOnError(new RuntimeException("没有收到任何回复"));
return;
}
// Stream 模式下,消息会自动被添加到 historiesPrompt 中,无需手动添加
// AiMessage aiMessage = new AiMessage(content);
// historiesPrompt.addMessage(aiMessage);
if (isReActAction(content)) {
if (processReActSteps(content)) {
// 递归继续执行下一个 ReAct 步骤
startNextReActStepStream();
}
} else if (isFinalAnswer(content)) {
String flag = reActStepParser.getFinalAnswerFlag();
String answer = content.substring(content.indexOf(flag) + flag.length());
notifyOnFinalAnswer(answer);
} else {
// 不是 Action
notifyOnNonActionResponseStream(context);
}
}
@Override
public void onFailure(ChatContext context, Throwable throwable) {
notifyOnError((Exception) throwable);
}
}, chatOptions);
}
private boolean isFinalAnswer(String content) {
return reActStepParser.isFinalAnswer(content);
}
private boolean isReActAction(String content) {
return reActStepParser.isReActAction(content);
}
// ========== 内部辅助方法 ==========
private String buildToolsDescription(List<Function> functions) {
StringBuilder sb = new StringBuilder();
for (Function function : functions) {
sb.append(" - ").append(function.getName()).append("(");
Parameter[] parameters = function.getParameters();
for (int i = 0; i < parameters.length; i++) {
Parameter param = parameters[i];
sb.append(param.getName()).append(": ").append(param.getType());
if (i < parameters.length - 1) {
sb.append(", ");
}
}
sb.append("): ").append(function.getDescription()).append("\n");
}
return sb.toString();
}
private boolean processReActSteps(String content) {
List<ReActStep> reActSteps = reActStepParser.parse(content);
if (reActSteps.isEmpty()) {
notifyOnStepParseError(content);
return false;
}
for (ReActStep step : reActSteps) {
boolean stepExecuted = false;
for (Function function : functions) {
if (function.getName().equals(step.getAction())) {
try {
notifyOnActionStart(step);
Map<String, Object> parameters;
try {
if (StringUtil.hasText(step.getActionInput())) {
parameters = JSON.parseObject(step.getActionInput());
} else {
parameters = Collections.emptyMap();
}
} catch (Exception e) {
log.error(e.toString(), e);
notifyOnActionJsonParserError(step, e);
if (!continueOnActionJsonParseError) {
return false;
}
Message message = messageBuilder.buildJsonParserErrorMessage(e, step);
historiesPrompt.addMessage(message);
return true; // 继续让 AI 修正
}
Object result = function.invoke(parameters);
notifyOnActionEnd(step, result);
Message message = messageBuilder.buildObservationMessage(step, result);
historiesPrompt.addMessage(message);
stepExecuted = true;
} catch (Exception e) {
log.error(e.toString(), e);
notifyOnActionInvokeError(e);
if (!continueOnActionInvokeError) {
return false;
}
Message message = messageBuilder.buildActionErrorMessage(step, e);
historiesPrompt.addMessage(message);
return true;
}
break;
}
}
if (!stepExecuted) {
notifyOnActionNotMatched(step, functions);
return false;
}
}
return true;
}
// ========== 通知监听器的方法 ==========
private void notifyOnChatResponse(AiMessageResponse response) {
for (ReActAgentListener l : listeners) {
l.onChatResponse(response);
}
}
private void notifyOnNonActionResponse(AiMessageResponse response) {
for (ReActAgentListener l : listeners) {
l.onNonActionResponse(response);
}
}
private void notifyOnNonActionResponseStream(ChatContext context) {
for (ReActAgentListener l : listeners) {
l.onNonActionResponseStream(context);
}
}
private void notifyOnChatResponseStream(ChatContext context, AiMessageResponse response) {
for (ReActAgentListener l : listeners) {
l.onChatResponseStream(context, response);
}
}
private void notifyOnFinalAnswer(String finalAnswer) {
for (ReActAgentListener l : listeners) {
l.onFinalAnswer(finalAnswer);
}
}
private void notifyOnActionStart(ReActStep reActStep) {
for (ReActAgentListener l : listeners) {
l.onActionStart(reActStep);
}
}
private void notifyOnActionEnd(ReActStep reActStep, Object result) {
for (ReActAgentListener l : listeners) {
l.onActionEnd(reActStep, result);
}
}
private void notifyOnMaxIterationsReached() {
for (ReActAgentListener l : listeners) {
l.onMaxIterationsReached();
}
}
private void notifyOnStepParseError(String content) {
for (ReActAgentListener l : listeners) {
l.onStepParseError(content);
}
}
private void notifyOnActionNotMatched(ReActStep step, List<Function> functions) {
for (ReActAgentListener l : listeners) {
l.onActionNotMatched(step, functions);
}
}
private void notifyOnActionJsonParserError(ReActStep step, Exception e) {
for (ReActAgentListener l : listeners) {
l.onActionJsonParserError(step, e);
}
}
private void notifyOnActionInvokeError(Exception e) {
for (ReActAgentListener l : listeners) {
l.onActionInvokeError(e);
}
}
private void notifyOnError(Exception error) {
for (ReActAgentListener l : listeners) {
l.onError(error);
}
}
}