diff --git a/backend/apps/ai_model/openai/llm.py b/backend/apps/ai_model/openai/llm.py new file mode 100644 index 0000000..2224868 --- /dev/null +++ b/backend/apps/ai_model/openai/llm.py @@ -0,0 +1,167 @@ +from typing import Dict, Optional, Any, Iterator, cast, Mapping + +from langchain_core.language_models import LanguageModelInput +from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \ + SystemMessageChunk, FunctionMessageChunk, ChatMessageChunk +from langchain_core.messages.ai import UsageMetadata +from langchain_core.messages.tool import tool_call_chunk, ToolMessageChunk +from langchain_core.outputs import ChatGenerationChunk +from langchain_core.runnables import RunnableConfig, ensure_config +from langchain_openai import ChatOpenAI +from langchain_openai.chat_models.base import _create_usage_metadata + + +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: type[BaseMessageChunk] +) -> BaseMessageChunk: + id_ = _dict.get("id") + role = cast(str, _dict.get("role")) + content = cast(str, _dict.get("content") or "") + additional_kwargs: dict = {} + if 'reasoning_content' in _dict: + additional_kwargs['reasoning_content'] = _dict.get('reasoning_content') + if _dict.get("function_call"): + function_call = dict(_dict["function_call"]) + if "name" in function_call and function_call["name"] is None: + function_call["name"] = "" + additional_kwargs["function_call"] = function_call + tool_call_chunks = [] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + try: + tool_call_chunks = [ + tool_call_chunk( + name=rtc["function"].get("name"), + args=rtc["function"].get("arguments"), + id=rtc.get("id"), + index=rtc["index"], + ) + for rtc in raw_tool_calls + ] + except KeyError: + pass + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content, id=id_) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + id=id_, + tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] + ) + elif role in ("system", "developer") or default_class == SystemMessageChunk: + if role == "developer": + additional_kwargs = {"__openai_role__": "developer"} + else: + additional_kwargs = {} + return SystemMessageChunk( + content=content, id=id_, additional_kwargs=additional_kwargs + ) + elif role == "function" or default_class == FunctionMessageChunk: + return FunctionMessageChunk(content=content, name=_dict["name"], id=id_) + elif role == "tool" or default_class == ToolMessageChunk: + return ToolMessageChunk( + content=content, tool_call_id=_dict["tool_call_id"], id=id_ + ) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role, id=id_) + else: + return default_class(content=content, id=id_) + + +class BaseChatOpenAI(ChatOpenAI): + usage_metadata: dict = {} + + # custom_get_token_ids = custom_get_token_ids + + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return self.usage_metadata + + def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]: + kwargs['stream_usage'] = True + for chunk in super()._stream(*args, **kwargs): + if chunk.message.usage_metadata is not None: + self.usage_metadata = chunk.message.usage_metadata + yield chunk + + def _convert_chunk_to_generation_chunk( + self, + chunk: dict, + default_chunk_class: type, + base_generation_info: Optional[dict], + ) -> Optional[ChatGenerationChunk]: + if chunk.get("type") == "content.delta": # from beta.chat.completions.stream + return None + token_usage = chunk.get("usage") + choices = ( + chunk.get("choices", []) + # from beta.chat.completions.stream + or chunk.get("chunk", {}).get("choices", []) + ) + + usage_metadata: Optional[UsageMetadata] = ( + _create_usage_metadata(token_usage) if token_usage and token_usage.get("prompt_tokens") else None + ) + if len(choices) == 0: + # logprobs is implicitly None + generation_chunk = ChatGenerationChunk( + message=default_chunk_class(content="", usage_metadata=usage_metadata) + ) + return generation_chunk + + choice = choices[0] + if choice["delta"] is None: + return None + + message_chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) + generation_info = {**base_generation_info} if base_generation_info else {} + + if finish_reason := choice.get("finish_reason"): + generation_info["finish_reason"] = finish_reason + if model_name := chunk.get("model"): + generation_info["model_name"] = model_name + if system_fingerprint := chunk.get("system_fingerprint"): + generation_info["system_fingerprint"] = system_fingerprint + + logprobs = choice.get("logprobs") + if logprobs: + generation_info["logprobs"] = logprobs + + if usage_metadata and isinstance(message_chunk, AIMessageChunk): + message_chunk.usage_metadata = usage_metadata + + generation_chunk = ChatGenerationChunk( + message=message_chunk, generation_info=generation_info or None + ) + return generation_chunk + + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> BaseMessage: + config = ensure_config(config) + chat_result = cast( + "ChatGeneration", + self.generate_prompt( + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + run_id=config.pop("run_id", None), + **kwargs, + ).generations[0][0], + + ).message + + self.usage_metadata = chat_result.response_metadata[ + 'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata + return chat_result