From 32ae8a340a4c255190a64c68d46a847fff404bec Mon Sep 17 00:00:00 2001 From: 0007 <0007@qq.com> Date: Wed, 27 Aug 2025 19:58:38 +0800 Subject: [PATCH] Add File --- .../com/agentsflex/core/react/ReActAgent.java | 497 ++++++++++++++++++ 1 file changed, 497 insertions(+) create mode 100644 agents-flex-core/src/main/java/com/agentsflex/core/react/ReActAgent.java diff --git a/agents-flex-core/src/main/java/com/agentsflex/core/react/ReActAgent.java b/agents-flex-core/src/main/java/com/agentsflex/core/react/ReActAgent.java new file mode 100644 index 0000000..49d37fe --- /dev/null +++ b/agents-flex-core/src/main/java/com/agentsflex/core/react/ReActAgent.java @@ -0,0 +1,497 @@ +/* + * Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com). + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.agentsflex.core.react; + +import com.agentsflex.core.llm.ChatContext; +import com.agentsflex.core.llm.ChatOptions; +import com.agentsflex.core.llm.Llm; +import com.agentsflex.core.llm.StreamResponseListener; +import com.agentsflex.core.llm.functions.Function; +import com.agentsflex.core.llm.functions.Parameter; +import com.agentsflex.core.llm.response.AiMessageResponse; +import com.agentsflex.core.message.AiMessage; +import com.agentsflex.core.message.Message; +import com.agentsflex.core.prompt.HistoriesPrompt; +import com.agentsflex.core.util.StringUtil; +import com.alibaba.fastjson.JSON; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * ReActAgent 是一个通用的 ReAct 模式 Agent,支持 Reasoning + Action 的交互方式。 + */ +public class ReActAgent { + + private static final Logger log = LoggerFactory.getLogger(ReActAgent.class); + + private static final String DEFAULT_PROMPT_TEMPLATE = + "你是一个 ReAct Agent,结合 Reasoning(推理)和 Action(行动)来解决问题。\n" + + "但在处理用户问题时,请首先判断:\n" + + "1. 如果问题可以通过你的常识或已有知识直接回答 → 请忽略 ReAct 框架,直接输出自然语言回答。\n" + + "2. 如果问题需要调用特定工具才能解决(如查询、计算、获取外部信息等)→ 请严格按照 ReAct 格式响应。\n" + + "\n" + + "如果你选择使用 ReAct 模式,请遵循以下格式:\n" + + "Thought: 描述你对当前问题的理解,包括已知信息和缺失信息,说明你下一步将采取什么行动及其原因。\n" + + "Action: 从下方列出的工具中选择一个合适的工具,仅输出工具名称,不得虚构。\n" + + "Action Input: 使用标准 JSON 格式提供该工具所需的参数,确保字段名与工具描述一致。\n" + + "\n" + + "在 ReAct 模式下,如果你已获得足够信息可以直接回答用户,请输出:\n" + + "Final Answer: [你的回答]\n" + + "\n" + + "注意事项:\n" + + "1. 每次只能选择一个工具并执行一个动作。\n" + + "2. 在未收到工具执行结果前,不要自行假设其输出。\n" + + "3. 不得编造工具或参数,所有工具均列于下方。\n" + + "4. 输出顺序必须为:Thought → Action → Action Input。\n" + + "\n" + + "### 可用工具列表:\n" + + "{tools}\n" + + "\n" + + "### 用户问题如下:\n" + + "{user_input}"; + + // 默认最大迭代次数 + private static final int DEFAULT_MAX_ITERATIONS = 5; + + private final Llm llm; + private final List functions; + private final String userQuery; + + private boolean streamable = false; + private int maxIterations = DEFAULT_MAX_ITERATIONS; + private String promptTemplate = DEFAULT_PROMPT_TEMPLATE; + private ReActStepParser reActStepParser = ReActStepParser.DEFAULT; // 默认解析器 + private final HistoriesPrompt historiesPrompt; + private ChatOptions chatOptions = ChatOptions.DEFAULT; + private ReActMessageBuilder messageBuilder = new ReActMessageBuilder(); + + + // 是否继续执行当 JSON 解析出错时 + private boolean continueOnActionJsonParseError = true; + + // 是否继续执行当 Function 调用出错时 + private boolean continueOnActionInvokeError = true; + + // 监听器集合 + private final List listeners = new ArrayList<>(); + + private int iterationCount = 0; + + public ReActAgent(Llm llm, List functions, String userQuery) { + this.llm = llm; + this.functions = functions; + this.userQuery = userQuery; + this.historiesPrompt = new HistoriesPrompt(); + } + + public ReActAgent(Llm llm, List functions, String userQuery, HistoriesPrompt historiesPrompt) { + this.llm = llm; + this.functions = functions; + this.userQuery = userQuery; + this.historiesPrompt = historiesPrompt; + } + + /** + * 注册监听器 + */ + public void addListener(ReActAgentListener listener) { + listeners.add(listener); + } + + /** + * 移除监听器 + */ + public void removeListener(ReActAgentListener listener) { + listeners.remove(listener); + } + + public Llm getLlm() { + return llm; + } + + public List getFunctions() { + return functions; + } + + public String getUserQuery() { + return userQuery; + } + + public int getMaxIterations() { + return maxIterations; + } + + public void setMaxIterations(int maxIterations) { + this.maxIterations = maxIterations; + } + + public String getPromptTemplate() { + return promptTemplate; + } + + public void setPromptTemplate(String promptTemplate) { + this.promptTemplate = promptTemplate; + } + + public ReActStepParser getReActStepParser() { + return reActStepParser; + } + + public void setReActStepParser(ReActStepParser reActStepParser) { + this.reActStepParser = reActStepParser; + } + + public List getListeners() { + return listeners; + } + + public boolean isStreamable() { + return streamable; + } + + public void setStreamable(boolean streamable) { + this.streamable = streamable; + } + + public HistoriesPrompt getHistoriesPrompt() { + return historiesPrompt; + } + + public int getIterationCount() { + return iterationCount; + } + + public void setIterationCount(int iterationCount) { + this.iterationCount = iterationCount; + } + + public boolean isContinueOnActionJsonParseError() { + return continueOnActionJsonParseError; + } + + public void setContinueOnActionJsonParseError(boolean continueOnActionJsonParseError) { + this.continueOnActionJsonParseError = continueOnActionJsonParseError; + } + + public boolean isContinueOnActionInvokeError() { + return continueOnActionInvokeError; + } + + public void setContinueOnActionInvokeError(boolean continueOnActionInvokeError) { + this.continueOnActionInvokeError = continueOnActionInvokeError; + } + + public ReActMessageBuilder getMessageBuilder() { + return messageBuilder; + } + + public void setMessageBuilder(ReActMessageBuilder messageBuilder) { + this.messageBuilder = messageBuilder; + } + + public ChatOptions getChatOptions() { + return chatOptions; + } + + public void setChatOptions(ChatOptions chatOptions) { + this.chatOptions = chatOptions; + } + + /** + * 运行 ReAct Agent 流程 + */ + public void run() { + try { + String toolsDescription = buildToolsDescription(functions); + String prompt = promptTemplate + .replace("{tools}", toolsDescription) + .replace("{user_input}", userQuery); + + Message message = messageBuilder.buildStartMessage(prompt, functions, userQuery); + historiesPrompt.addMessage(message); + + if (this.isStreamable()) { + startNextReActStepStream(); + } else { + startNextReactStepNormal(); + } + + } catch (Exception e) { + log.error("运行 ReAct Agent 出错:" + e); + notifyOnError(e); + } + } + + private void startNextReactStepNormal() { + for (int i = 0; i < maxIterations; i++) { + AiMessageResponse response = llm.chat(historiesPrompt, chatOptions); + String content = response.getMessage().getContent(); + historiesPrompt.addMessage(new AiMessage(content)); + + notifyOnChatResponse(response); + + if (isReActAction(content)) { + if (!processReActSteps(content)) { + break; + } + } else if (isFinalAnswer(content)) { + String flag = reActStepParser.getFinalAnswerFlag(); + String answer = content.substring(content.indexOf(flag) + flag.length()); + notifyOnFinalAnswer(answer); + break; + } else { + // 不是Action + notifyOnNonActionResponse(response); + break; + } + } + } + + private void startNextReActStepStream() { + if (iterationCount >= maxIterations) { + notifyOnMaxIterationsReached(); + return; + } + + iterationCount++; + + llm.chatStream(historiesPrompt, new StreamResponseListener() { + + @Override + public void onMessage(ChatContext context, AiMessageResponse response) { + notifyOnChatResponseStream(context, response); + } + + @Override + public void onStop(ChatContext context) { + AiMessage lastAiMessage = context.getLastAiMessage(); + if (lastAiMessage == null) { + notifyOnError(new RuntimeException("没有收到任何回复")); + return; + } + + String content = lastAiMessage.getFullContent(); + if (StringUtil.noText(content)) { + notifyOnError(new RuntimeException("没有收到任何回复")); + return; + } + + // Stream 模式下,消息会自动被添加到 historiesPrompt 中,无需手动添加 + // AiMessage aiMessage = new AiMessage(content); + // historiesPrompt.addMessage(aiMessage); + + if (isReActAction(content)) { + if (processReActSteps(content)) { + // 递归继续执行下一个 ReAct 步骤 + startNextReActStepStream(); + } + } else if (isFinalAnswer(content)) { + String flag = reActStepParser.getFinalAnswerFlag(); + String answer = content.substring(content.indexOf(flag) + flag.length()); + notifyOnFinalAnswer(answer); + } else { + // 不是 Action + notifyOnNonActionResponseStream(context); + } + } + + @Override + public void onFailure(ChatContext context, Throwable throwable) { + notifyOnError((Exception) throwable); + } + }, chatOptions); + } + + + private boolean isFinalAnswer(String content) { + return reActStepParser.isFinalAnswer(content); + } + + private boolean isReActAction(String content) { + return reActStepParser.isReActAction(content); + } + + // ========== 内部辅助方法 ========== + + private String buildToolsDescription(List functions) { + StringBuilder sb = new StringBuilder(); + for (Function function : functions) { + sb.append(" - ").append(function.getName()).append("("); + Parameter[] parameters = function.getParameters(); + for (int i = 0; i < parameters.length; i++) { + Parameter param = parameters[i]; + sb.append(param.getName()).append(": ").append(param.getType()); + if (i < parameters.length - 1) { + sb.append(", "); + } + } + sb.append("): ").append(function.getDescription()).append("\n"); + } + return sb.toString(); + } + + + private boolean processReActSteps(String content) { + List reActSteps = reActStepParser.parse(content); + if (reActSteps.isEmpty()) { + notifyOnStepParseError(content); + return false; + } + + for (ReActStep step : reActSteps) { + boolean stepExecuted = false; + for (Function function : functions) { + if (function.getName().equals(step.getAction())) { + try { + notifyOnActionStart(step); + + Map parameters; + try { + if (StringUtil.hasText(step.getActionInput())) { + parameters = JSON.parseObject(step.getActionInput()); + } else { + parameters = Collections.emptyMap(); + } + } catch (Exception e) { + log.error(e.toString(), e); + notifyOnActionJsonParserError(step, e); + + if (!continueOnActionJsonParseError) { + return false; + } + + Message message = messageBuilder.buildJsonParserErrorMessage(e, step); + historiesPrompt.addMessage(message); + return true; // 继续让 AI 修正 + } + + Object result = function.invoke(parameters); + notifyOnActionEnd(step, result); + + Message message = messageBuilder.buildObservationMessage(step, result); + historiesPrompt.addMessage(message); + stepExecuted = true; + } catch (Exception e) { + log.error(e.toString(), e); + notifyOnActionInvokeError(e); + + if (!continueOnActionInvokeError) { + return false; + } + + Message message = messageBuilder.buildActionErrorMessage(step, e); + historiesPrompt.addMessage(message); + return true; + } + break; + } + } + + if (!stepExecuted) { + notifyOnActionNotMatched(step, functions); + return false; + } + } + + + return true; + } + + + // ========== 通知监听器的方法 ========== + private void notifyOnChatResponse(AiMessageResponse response) { + for (ReActAgentListener l : listeners) { + l.onChatResponse(response); + } + } + + private void notifyOnNonActionResponse(AiMessageResponse response) { + for (ReActAgentListener l : listeners) { + l.onNonActionResponse(response); + } + } + + private void notifyOnNonActionResponseStream(ChatContext context) { + for (ReActAgentListener l : listeners) { + l.onNonActionResponseStream(context); + } + } + + private void notifyOnChatResponseStream(ChatContext context, AiMessageResponse response) { + for (ReActAgentListener l : listeners) { + l.onChatResponseStream(context, response); + } + } + + private void notifyOnFinalAnswer(String finalAnswer) { + for (ReActAgentListener l : listeners) { + l.onFinalAnswer(finalAnswer); + } + } + + private void notifyOnActionStart(ReActStep reActStep) { + for (ReActAgentListener l : listeners) { + l.onActionStart(reActStep); + } + } + + private void notifyOnActionEnd(ReActStep reActStep, Object result) { + for (ReActAgentListener l : listeners) { + l.onActionEnd(reActStep, result); + } + } + + private void notifyOnMaxIterationsReached() { + for (ReActAgentListener l : listeners) { + l.onMaxIterationsReached(); + } + } + + private void notifyOnStepParseError(String content) { + for (ReActAgentListener l : listeners) { + l.onStepParseError(content); + } + } + + private void notifyOnActionNotMatched(ReActStep step, List functions) { + for (ReActAgentListener l : listeners) { + l.onActionNotMatched(step, functions); + } + } + + private void notifyOnActionJsonParserError(ReActStep step, Exception e) { + for (ReActAgentListener l : listeners) { + l.onActionJsonParserError(step, e); + } + } + + private void notifyOnActionInvokeError(Exception e) { + for (ReActAgentListener l : listeners) { + l.onActionInvokeError(e); + } + } + + private void notifyOnError(Exception error) { + for (ReActAgentListener l : listeners) { + l.onError(error); + } + } +}