Files
SQLBot/backend/apps/ai_model/model_factory.py
2025-09-08 16:36:25 +08:00

171 lines
5.6 KiB
Python

from functools import lru_cache
import json
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Type
from langchain.chat_models.base import BaseChatModel
from pydantic import BaseModel
from sqlmodel import Session, select
from apps.ai_model.openai.llm import BaseChatOpenAI
from apps.system.models.system_model import AiModelDetail
from common.core.db import engine
from common.utils.crypto import sqlbot_decrypt
from common.utils.utils import prepare_model_arg
from langchain_community.llms import VLLMOpenAI
from langchain_openai import AzureChatOpenAI
# from langchain_community.llms import Tongyi, VLLM
class LLMConfig(BaseModel):
"""Base configuration class for large language models"""
model_id: Optional[int] = None
model_type: str # Model type: openai/tongyi/vllm etc.
model_name: str # Specific model name
api_key: Optional[str] = None
api_base_url: Optional[str] = None
additional_params: Dict[str, Any] = {}
class Config:
frozen = True
def __hash__(self):
if hasattr(self, 'additional_params') and isinstance(self.additional_params, dict):
hashable_params = frozenset((k, tuple(v) if isinstance(v, (list, dict)) else v)
for k, v in self.additional_params.items())
else:
hashable_params = None
return hash((
self.model_id,
self.model_type,
self.model_name,
self.api_key,
self.api_base_url,
hashable_params
))
class BaseLLM(ABC):
"""Abstract base class for large language models"""
def __init__(self, config: LLMConfig):
self.config = config
self._llm = self._init_llm()
@abstractmethod
def _init_llm(self) -> BaseChatModel:
"""Initialize specific large language model instance"""
pass
@property
def llm(self) -> BaseChatModel:
"""Return the langchain LLM instance"""
return self._llm
class OpenAIvLLM(BaseLLM):
def _init_llm(self) -> VLLMOpenAI:
return VLLMOpenAI(
openai_api_key=self.config.api_key or 'Empty',
openai_api_base=self.config.api_base_url,
model_name=self.config.model_name,
streaming=True,
**self.config.additional_params,
)
class OpenAIAzureLLM(BaseLLM):
def _init_llm(self) -> AzureChatOpenAI:
api_version = self.config.additional_params.get("api_version")
deployment_name = self.config.additional_params.get("deployment_name")
if api_version:
self.config.additional_params.pop("api_version")
if deployment_name:
self.config.additional_params.pop("deployment_name")
return AzureChatOpenAI(
azure_endpoint=self.config.api_base_url,
api_key=self.config.api_key or 'Empty',
model_name=self.config.model_name,
api_version=api_version,
deployment_name=deployment_name,
streaming=True,
**self.config.additional_params,
)
class OpenAILLM(BaseLLM):
def _init_llm(self) -> BaseChatModel:
return BaseChatOpenAI(
model=self.config.model_name,
api_key=self.config.api_key or 'Empty',
base_url=self.config.api_base_url,
stream_usage=True,
**self.config.additional_params,
)
def generate(self, prompt: str) -> str:
return self.llm.invoke(prompt)
class LLMFactory:
"""Large Language Model Factory Class"""
_llm_types: Dict[str, Type[BaseLLM]] = {
"openai": OpenAILLM,
"tongyi": OpenAILLM,
"vllm": OpenAIvLLM,
"azure": OpenAIAzureLLM,
}
@classmethod
@lru_cache(maxsize=32)
def create_llm(cls, config: LLMConfig) -> BaseLLM:
llm_class = cls._llm_types.get(config.model_type)
if not llm_class:
raise ValueError(f"Unsupported LLM type: {config.model_type}")
return llm_class(config)
@classmethod
def register_llm(cls, model_type: str, llm_class: Type[BaseLLM]):
"""Register new model type"""
cls._llm_types[model_type] = llm_class
# todo
""" def get_llm_config(aimodel: AiModelDetail) -> LLMConfig:
config = LLMConfig(
model_type="openai",
model_name=aimodel.name,
api_key=aimodel.api_key,
api_base_url=aimodel.endpoint,
additional_params={"temperature": aimodel.temperature}
)
return config """
async def get_default_config() -> LLMConfig:
with Session(engine) as session:
db_model = session.exec(
select(AiModelDetail).where(AiModelDetail.default_model == True)
).first()
if not db_model:
raise Exception("The system default model has not been set")
additional_params = {}
if db_model.config:
try:
config_raw = json.loads(db_model.config)
additional_params = {item["key"]: prepare_model_arg(item.get('val')) for item in config_raw if "key" in item and "val" in item}
except Exception:
pass
if not db_model.api_domain.startswith("http"):
db_model.api_domain = await sqlbot_decrypt(db_model.api_domain)
if db_model.api_key:
db_model.api_key = await sqlbot_decrypt(db_model.api_key)
# 构造 LLMConfig
return LLMConfig(
model_id=db_model.id,
model_type="openai" if db_model.protocol == 1 else "vllm",
model_name=db_model.base_model,
api_key=db_model.api_key,
api_base_url=db_model.api_domain,
additional_params=additional_params,
)