diff --git a/agents-flex-llm/agents-flex-llm-coze/src/main/java/com/agentsflex/llm/coze/CozeLlm.java b/agents-flex-llm/agents-flex-llm-coze/src/main/java/com/agentsflex/llm/coze/CozeLlm.java new file mode 100644 index 0000000..5d147eb --- /dev/null +++ b/agents-flex-llm/agents-flex-llm-coze/src/main/java/com/agentsflex/llm/coze/CozeLlm.java @@ -0,0 +1,340 @@ +/* + * Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com). + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.agentsflex.llm.coze; + +import com.agentsflex.core.document.Document; +import com.agentsflex.core.llm.BaseLlm; +import com.agentsflex.core.llm.ChatOptions; +import com.agentsflex.core.llm.StreamResponseListener; +import com.agentsflex.core.llm.client.HttpClient; +import com.agentsflex.core.llm.embedding.EmbeddingOptions; +import com.agentsflex.core.llm.response.AiMessageResponse; +import com.agentsflex.core.message.AiMessage; +import com.agentsflex.core.message.Message; +import com.agentsflex.core.parser.AiMessageParser; +import com.agentsflex.core.prompt.Prompt; +import com.agentsflex.core.store.VectorData; +import com.agentsflex.core.util.LogUtil; +import com.agentsflex.core.util.StringUtil; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONArray; +import com.alibaba.fastjson.JSONObject; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.lang.reflect.Type; +import java.nio.charset.Charset; +import java.util.*; +import java.util.concurrent.CountDownLatch; +import java.util.stream.Collectors; + + +/** + * @author yulsh + */ +public class CozeLlm extends BaseLlm { + + private final HttpClient httpClient = new HttpClient(); + private final AiMessageParser aiMessageParser = CozeLlmUtil.getAiMessageParser(); + + public CozeLlm(CozeLlmConfig config) { + super(config); + } + + private Map buildHeader() { + Map headers = new HashMap<>(); + headers.put("Content-Type", "application/json"); + headers.put("Authorization", "Bearer " + config.getApiKey()); + return headers; + } + + private void botChat(Prompt prompt, CozeRequestListener listener, ChatOptions chatOptions, boolean stream) { + String botId = config.getDefaultBotId(); + String userId = config.getDefaultUserId(); + String conversationId = config.getDefaultConversationId(); + Map customVariables = null; + + if (chatOptions instanceof CozeChatOptions) { + CozeChatOptions options = (CozeChatOptions) chatOptions; + botId = StringUtil.hasText(options.getBotId()) ? options.getBotId() : botId; + userId = StringUtil.hasText(options.getUserId()) ? options.getUserId() : userId; + conversationId = StringUtil.hasText(options.getConversationId()) ? options.getConversationId() : conversationId; + customVariables = options.getCustomVariables(); + } + + String payload = CozeLlmUtil.promptToPayload(prompt, botId, userId, customVariables, stream); + String url = config.getEndpoint() + config.getChatApi(); + if (StringUtil.hasText(conversationId)) { + url += "?conversation_id=" + conversationId; + } + String response = httpClient.post(url, buildHeader(), payload); + + if (config.isDebug()) { + LogUtil.println(">>>>receive payload:" + response); + } + + // stream mode + if (stream) { + handleStreamResponse(response, listener); + return; + } + + JSONObject jsonObject = JSON.parseObject(response); + String code = jsonObject.getString("code"); + String error = jsonObject.getString("msg"); + + CozeChatContext cozeChat = jsonObject.getObject("data", (Type) CozeChatContext.class); + + if (!error.isEmpty() && !Objects.equals(code, "0")) { + if (cozeChat == null) { + cozeChat = new CozeChatContext(); + cozeChat.setLlm(this); + cozeChat.setResponse(response); + } + listener.onFailure(cozeChat, new Throwable(error)); + listener.onStop(cozeChat); + return; + } else if (cozeChat != null) { + cozeChat.setLlm(this); + cozeChat.setResponse(response); + } + + // try to check status + int attemptCount = 0; + boolean isCompleted = false; + int maxAttempts = 20; + while (attemptCount < maxAttempts && !isCompleted) { + attemptCount++; + try { + cozeChat = checkStatus(cozeChat); + listener.onMessage(cozeChat); + + isCompleted = Objects.equals(cozeChat.getStatus(), "completed"); + if (isCompleted || attemptCount == maxAttempts) { + listener.onStop(cozeChat); + break; + } + Thread.sleep(1000); + } catch (Exception e) { + listener.onFailure(cozeChat, e.getCause()); + listener.onStop(cozeChat); + Thread.currentThread().interrupt(); + } + } + } + + private void handleStreamResponse(String response, CozeRequestListener listener) { + ByteArrayInputStream inputStream = new ByteArrayInputStream(response.getBytes(Charset.defaultCharset())); + BufferedReader br = new BufferedReader(new InputStreamReader(inputStream, Charset.defaultCharset())); + CozeChatContext context = new CozeChatContext(); + context.setLlm(this); + + // 记录completed消息,在处理完answer消息后再进行处理 + CozeChatContext completedContext = null; + + + List messageList = new ArrayList<>(); + try { + // 在处理消息前,先进行初始化,保持与其他LLM流式处理流程一致 + listener.onStart(context); + + String line; + while ((line = br.readLine()) != null) { + if (line.trim().isEmpty() || !line.startsWith("data:") || line.contains("[DONE]")) { + continue; + } + + //remove "data:" + line = line.substring(5); + JSONObject data = JSON.parseObject(line); + String status = data.getString("status"); + String type = data.getString("type"); + if ("completed".equalsIgnoreCase(status)) { + completedContext = JSON.parseObject(line, CozeChatContext.class); + completedContext.setResponse(line); + continue; + } + // N 条answer,最后一条是完整的 + if ("answer".equalsIgnoreCase(type)) { + AiMessage message = new AiMessage(); + message.setContent(data.getString("content")); + messageList.add(message); + } + } + if (!messageList.isEmpty()) { + // 删除最后一条完整的之后输出 + messageList.remove(messageList.size() - 1); + for (AiMessage m : messageList) { + context.setMessage(m); + listener.onMessage(context); + Thread.sleep(10); + } + } + + if (completedContext != null) { + listener.onStop(completedContext); + } + } catch (IOException ex) { + listener.onFailure(context, ex.getCause()); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + private CozeChatContext checkStatus(CozeChatContext cozeChat) { + String chatId = cozeChat.getId(); + String conversationId = cozeChat.getConversationId(); + String url = String.format("%s/v3/chat/retrieve?chat_id=%s&conversation_id=%s", config.getEndpoint(), chatId, conversationId); + String response = httpClient.get(url, buildHeader()); + JSONObject resObj = JSON.parseObject(response); + // 需要返回最新的response信息,否则会导致调用方获取不到conversation_id等完整信息 + CozeChatContext cozeChatContext = resObj.getObject("data", (Type) CozeChatContext.class); + cozeChatContext.setResponse(response); + return cozeChatContext; + } + + private JSONArray fetchMessageList(CozeChatContext cozeChat) { + String chatId = cozeChat.getId(); + String conversationId = cozeChat.getConversationId(); + String endpoint = config.getEndpoint(); + String url = String.format("%s/v3/chat/message/list?chat_id=%s&conversation_id=%s", endpoint, chatId, conversationId); + String response = httpClient.get(url, buildHeader()); + JSONObject jsonObject = JSON.parseObject(response); + String code = jsonObject.getString("code"); + String error = jsonObject.getString("msg"); + JSONArray messageList = jsonObject.getJSONArray("data"); + if (!error.isEmpty() && !Objects.equals(code, "0")) { + return null; + } + return messageList; + } + + public AiMessage getChatAnswer(CozeChatContext cozeChat) { + JSONArray messageList = fetchMessageList(cozeChat); + if (messageList == null || messageList.isEmpty()) { + return null; + } + List objects = messageList.stream() + .map(JSONObject.class::cast) + .filter(obj -> "answer".equals(obj.getString("type"))) + .collect(Collectors.toList()); + JSONObject answer = !objects.isEmpty() ? objects.get(0) : null; + if (answer != null) { + /* + * coze上的工作流一个请求可以返回多条消息,需要全部返回,用3个换行符进行分隔 + * 使用3个换行符的原因: + * 若调用方不关心多条消息,不太影响直接展示; + * 若调用方关心多条消息,可以进行分割处理且3个换行符能减少误分隔的概率; + */ + StringBuilder sb = new StringBuilder(answer.getString("content")); + for (int i = 1; i < objects.size(); i++) { + sb.append("\n\n\n").append(objects.get(i).getString("content")); + } + answer.put("usage", cozeChat.getUsage()); + answer.put("content", sb.toString()); + return aiMessageParser.parse(answer); + } + return null; + } + + @Override + public VectorData embed(Document document, EmbeddingOptions options) { + throw new UnsupportedOperationException("Not supported yet."); + } + + + @Override + public AiMessageResponse chat(Prompt prompt, ChatOptions options) { + CountDownLatch latch = new CountDownLatch(1); + Message[] messages = new Message[1]; + String[] responses = new String[1]; + Throwable[] failureThrowable = new Throwable[1]; + + this.botChat(prompt, new CozeRequestListener() { + @Override + public void onStart(CozeChatContext context) { + } + + @Override + public void onMessage(CozeChatContext context) { + boolean isCompleted = Objects.equals(context.getStatus(), "completed"); + if (isCompleted) { + AiMessage answer = getChatAnswer(context); + messages[0] = answer; + responses[0] = context.getResponse(); + } + } + + @Override + public void onFailure(CozeChatContext context, Throwable throwable) { + failureThrowable[0] = throwable; + responses[0] = context.getResponse(); + latch.countDown(); + } + + @Override + public void onStop(CozeChatContext context) { + latch.countDown(); + } + }, options, false); + + try { + latch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + AiMessageResponse response = new AiMessageResponse(prompt, responses[0], (AiMessage) messages[0]); + + if (messages[0] == null || failureThrowable[0] != null) { + response.setError(true); + if (failureThrowable[0] != null) { + response.setErrorMessage(failureThrowable[0].getMessage()); + } + } + + return response; + } + + @Override + public void chatStream(Prompt prompt, StreamResponseListener listener, ChatOptions options) { + this.botChat(prompt, new CozeRequestListener() { + @Override + public void onStart(CozeChatContext context) { + listener.onStart(context); + } + + @Override + public void onMessage(CozeChatContext context) { + AiMessageResponse response = new AiMessageResponse(prompt, context.getResponse(), context.getMessage()); + listener.onMessage(context, response); + } + + @Override + public void onFailure(CozeChatContext context, Throwable throwable) { + listener.onFailure(context, throwable); + } + + @Override + public void onStop(CozeChatContext context) { + listener.onStop(context); + } + }, options, true); + } + +}