Add File

backend/apps/ai_model/model_factory.py · 170 lines · Normal file
@@ -0,0 +1,170 @@
import json
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Optional, Dict, Any, Type

from langchain.chat_models.base import BaseChatModel
from langchain_community.llms import VLLMOpenAI
from langchain_openai import AzureChatOpenAI
from pydantic import BaseModel
from sqlmodel import Session, select

from apps.ai_model.openai.llm import BaseChatOpenAI
from apps.system.models.system_model import AiModelDetail
from common.core.db import engine
from common.utils.crypto import sqlbot_decrypt
from common.utils.utils import prepare_model_arg

# from langchain_community.llms import Tongyi, VLLM

class LLMConfig(BaseModel):
    """Base configuration class for large language models"""
    model_id: Optional[int] = None
    model_type: str  # Model type: openai/tongyi/vllm etc.
    model_name: str  # Specific model name
    api_key: Optional[str] = None
    api_base_url: Optional[str] = None
    additional_params: Dict[str, Any] = {}

    class Config:
        frozen = True

    def __hash__(self):
        if hasattr(self, 'additional_params') and isinstance(self.additional_params, dict):
            # Make nested values hashable: dicts become sorted item tuples
            # (a bare tuple(v) would keep only the keys and drop the values),
            # lists become tuples.
            hashable_params = frozenset(
                (k,
                 tuple(sorted(v.items())) if isinstance(v, dict)
                 else tuple(v) if isinstance(v, list)
                 else v)
                for k, v in self.additional_params.items()
            )
        else:
            hashable_params = None

        return hash((
            self.model_id,
            self.model_type,
            self.model_name,
            self.api_key,
            self.api_base_url,
            hashable_params,
        ))

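# A minimal sketch of what the custom __hash__ buys us (model names and
# parameter values here are hypothetical): two structurally equal configs
# hash the same, so LLMFactory.create_llm's lru_cache below can treat them
# as one cache key and reuse the same client instance.
#
#   cfg_a = LLMConfig(model_type="openai", model_name="gpt-4o",
#                     additional_params={"temperature": 0.1})
#   cfg_b = LLMConfig(model_type="openai", model_name="gpt-4o",
#                     additional_params={"temperature": 0.1})
#   assert hash(cfg_a) == hash(cfg_b)
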
class BaseLLM(ABC):
    """Abstract base class for large language models"""

    def __init__(self, config: LLMConfig):
        self.config = config
        self._llm = self._init_llm()

    @abstractmethod
    def _init_llm(self) -> BaseChatModel:
        """Initialize the specific large language model instance"""
        pass

    @property
    def llm(self) -> BaseChatModel:
        """Return the langchain LLM instance"""
        return self._llm

class OpenAIvLLM(BaseLLM):
    def _init_llm(self) -> VLLMOpenAI:
        return VLLMOpenAI(
            openai_api_key=self.config.api_key or 'Empty',
            openai_api_base=self.config.api_base_url,
            model_name=self.config.model_name,
            streaming=True,
            **self.config.additional_params,
        )

class OpenAIAzureLLM(BaseLLM):
    def _init_llm(self) -> AzureChatOpenAI:
        # Copy before popping: additional_params belongs to a frozen config
        # that may be shared via the factory's lru_cache, so mutating it in
        # place would leak into later calls.
        params = dict(self.config.additional_params)
        api_version = params.pop("api_version", None)
        deployment_name = params.pop("deployment_name", None)
        return AzureChatOpenAI(
            azure_endpoint=self.config.api_base_url,
            api_key=self.config.api_key or 'Empty',
            model_name=self.config.model_name,
            api_version=api_version,
            deployment_name=deployment_name,
            streaming=True,
            **params,
        )

class OpenAILLM(BaseLLM):
    def _init_llm(self) -> BaseChatModel:
        return BaseChatOpenAI(
            model=self.config.model_name,
            api_key=self.config.api_key or 'Empty',
            base_url=self.config.api_base_url,
            stream_usage=True,
            **self.config.additional_params,
        )

    def generate(self, prompt: str) -> str:
        # invoke() returns a chat message; expose its text content to match
        # the declared return type.
        return self.llm.invoke(prompt).content

class LLMFactory:
    """Large language model factory class"""

    _llm_types: Dict[str, Type[BaseLLM]] = {
        "openai": OpenAILLM,
        "tongyi": OpenAILLM,
        "vllm": OpenAIvLLM,
        "azure": OpenAIAzureLLM,
    }

    @classmethod
    @lru_cache(maxsize=32)
    def create_llm(cls, config: LLMConfig) -> BaseLLM:
        llm_class = cls._llm_types.get(config.model_type)
        if not llm_class:
            raise ValueError(f"Unsupported LLM type: {config.model_type}")
        return llm_class(config)

    @classmethod
    def register_llm(cls, model_type: str, llm_class: Type[BaseLLM]):
        """Register a new model type"""
        cls._llm_types[model_type] = llm_class

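# A minimal usage sketch (endpoint, key, and model names are hypothetical):
#
#   config = LLMConfig(model_type="vllm", model_name="qwen2-72b-instruct",
#                      api_base_url="http://localhost:8000/v1")
#   llm = LLMFactory.create_llm(config).llm
#
# Third parties can plug in their own backend by subclassing BaseLLM and
# calling LLMFactory.register_llm("my_type", MyLLM) -- both names
# hypothetical -- before create_llm runs.
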
# todo
""" def get_llm_config(aimodel: AiModelDetail) -> LLMConfig:
    config = LLMConfig(
        model_type="openai",
        model_name=aimodel.name,
        api_key=aimodel.api_key,
        api_base_url=aimodel.endpoint,
        additional_params={"temperature": aimodel.temperature}
    )
    return config """

async def get_default_config() -> LLMConfig:
    with Session(engine) as session:
        db_model = session.exec(
            select(AiModelDetail).where(AiModelDetail.default_model == True)
        ).first()
        if not db_model:
            raise Exception("The system default model has not been set")

        additional_params = {}
        if db_model.config:
            try:
                config_raw = json.loads(db_model.config)
                additional_params = {
                    item["key"]: prepare_model_arg(item.get('val'))
                    for item in config_raw
                    if "key" in item and "val" in item
                }
            except Exception:
                # Ignore malformed config JSON and fall back to no extra params.
                pass
        # Stored values that do not look like a plain URL are encrypted,
        # so decrypt them before use.
        if db_model.api_domain and not db_model.api_domain.startswith("http"):
            db_model.api_domain = await sqlbot_decrypt(db_model.api_domain)
        if db_model.api_key:
            db_model.api_key = await sqlbot_decrypt(db_model.api_key)

        # Build the LLMConfig
        return LLMConfig(
            model_id=db_model.id,
            model_type="openai" if db_model.protocol == 1 else "vllm",
            model_name=db_model.base_model,
            api_key=db_model.api_key,
            api_base_url=db_model.api_domain,
            additional_params=additional_params,
        )
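
# For reference, a sketch of the config JSON shape get_default_config expects
# (keys and values are hypothetical):
#
#   [{"key": "temperature", "val": 0.1}, {"key": "max_tokens", "val": 4096}]
#
# Each entry becomes additional_params[key] after its value passes through
# prepare_model_arg; entries missing either "key" or "val" are skipped.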