From 3375478b9f9890dce9230443d558bb758828a524 Mon Sep 17 00:00:00 2001 From: inter Date: Mon, 8 Sep 2025 16:36:25 +0800 Subject: [PATCH] Add File --- backend/apps/ai_model/model_factory.py | 170 +++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 backend/apps/ai_model/model_factory.py diff --git a/backend/apps/ai_model/model_factory.py b/backend/apps/ai_model/model_factory.py new file mode 100644 index 0000000..03479fd --- /dev/null +++ b/backend/apps/ai_model/model_factory.py @@ -0,0 +1,170 @@ +from functools import lru_cache +import json +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any, Type + +from langchain.chat_models.base import BaseChatModel +from pydantic import BaseModel +from sqlmodel import Session, select + +from apps.ai_model.openai.llm import BaseChatOpenAI +from apps.system.models.system_model import AiModelDetail +from common.core.db import engine +from common.utils.crypto import sqlbot_decrypt +from common.utils.utils import prepare_model_arg +from langchain_community.llms import VLLMOpenAI +from langchain_openai import AzureChatOpenAI +# from langchain_community.llms import Tongyi, VLLM + +class LLMConfig(BaseModel): + """Base configuration class for large language models""" + model_id: Optional[int] = None + model_type: str # Model type: openai/tongyi/vllm etc. + model_name: str # Specific model name + api_key: Optional[str] = None + api_base_url: Optional[str] = None + additional_params: Dict[str, Any] = {} + class Config: + frozen = True + + def __hash__(self): + if hasattr(self, 'additional_params') and isinstance(self.additional_params, dict): + hashable_params = frozenset((k, tuple(v) if isinstance(v, (list, dict)) else v) + for k, v in self.additional_params.items()) + else: + hashable_params = None + + return hash(( + self.model_id, + self.model_type, + self.model_name, + self.api_key, + self.api_base_url, + hashable_params + )) + + +class BaseLLM(ABC): + """Abstract base class for large language models""" + + def __init__(self, config: LLMConfig): + self.config = config + self._llm = self._init_llm() + + @abstractmethod + def _init_llm(self) -> BaseChatModel: + """Initialize specific large language model instance""" + pass + + @property + def llm(self) -> BaseChatModel: + """Return the langchain LLM instance""" + return self._llm + +class OpenAIvLLM(BaseLLM): + def _init_llm(self) -> VLLMOpenAI: + return VLLMOpenAI( + openai_api_key=self.config.api_key or 'Empty', + openai_api_base=self.config.api_base_url, + model_name=self.config.model_name, + streaming=True, + **self.config.additional_params, + ) + +class OpenAIAzureLLM(BaseLLM): + def _init_llm(self) -> AzureChatOpenAI: + api_version = self.config.additional_params.get("api_version") + deployment_name = self.config.additional_params.get("deployment_name") + if api_version: + self.config.additional_params.pop("api_version") + if deployment_name: + self.config.additional_params.pop("deployment_name") + return AzureChatOpenAI( + azure_endpoint=self.config.api_base_url, + api_key=self.config.api_key or 'Empty', + model_name=self.config.model_name, + api_version=api_version, + deployment_name=deployment_name, + streaming=True, + **self.config.additional_params, + ) +class OpenAILLM(BaseLLM): + def _init_llm(self) -> BaseChatModel: + return BaseChatOpenAI( + model=self.config.model_name, + api_key=self.config.api_key or 'Empty', + base_url=self.config.api_base_url, + stream_usage=True, + **self.config.additional_params, + ) + + def generate(self, prompt: str) -> str: + return self.llm.invoke(prompt) + + +class LLMFactory: + """Large Language Model Factory Class""" + + _llm_types: Dict[str, Type[BaseLLM]] = { + "openai": OpenAILLM, + "tongyi": OpenAILLM, + "vllm": OpenAIvLLM, + "azure": OpenAIAzureLLM, + } + + @classmethod + @lru_cache(maxsize=32) + def create_llm(cls, config: LLMConfig) -> BaseLLM: + llm_class = cls._llm_types.get(config.model_type) + if not llm_class: + raise ValueError(f"Unsupported LLM type: {config.model_type}") + return llm_class(config) + + @classmethod + def register_llm(cls, model_type: str, llm_class: Type[BaseLLM]): + """Register new model type""" + cls._llm_types[model_type] = llm_class + + +# todo +""" def get_llm_config(aimodel: AiModelDetail) -> LLMConfig: + config = LLMConfig( + model_type="openai", + model_name=aimodel.name, + api_key=aimodel.api_key, + api_base_url=aimodel.endpoint, + additional_params={"temperature": aimodel.temperature} + ) + return config """ + + +async def get_default_config() -> LLMConfig: + with Session(engine) as session: + db_model = session.exec( + select(AiModelDetail).where(AiModelDetail.default_model == True) + ).first() + if not db_model: + raise Exception("The system default model has not been set") + + additional_params = {} + if db_model.config: + try: + config_raw = json.loads(db_model.config) + additional_params = {item["key"]: prepare_model_arg(item.get('val')) for item in config_raw if "key" in item and "val" in item} + except Exception: + pass + if not db_model.api_domain.startswith("http"): + db_model.api_domain = await sqlbot_decrypt(db_model.api_domain) + if db_model.api_key: + db_model.api_key = await sqlbot_decrypt(db_model.api_key) + + + # 构造 LLMConfig + return LLMConfig( + model_id=db_model.id, + model_type="openai" if db_model.protocol == 1 else "vllm", + model_name=db_model.base_model, + api_key=db_model.api_key, + api_base_url=db_model.api_domain, + additional_params=additional_params, + )