"""Process-wide cache of embedding model instances, keyed by a string."""
import os.path
import threading
from typing import Optional

from langchain_core.embeddings import Embeddings
from langchain_huggingface import HuggingFaceEmbeddings
from pydantic import BaseModel

from common.core.config import settings

# Set at import time, i.e. before any model is instantiated by this module.
# NOTE(review): presumably this disables the Hugging Face tokenizers' own
# parallelism to avoid its fork warning -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class EmbeddingModelInfo(BaseModel):
    """Settings needed to construct one ``HuggingFaceEmbeddings`` instance."""

    # Passed as ``cache_folder`` when the model is created.
    folder: str
    # Passed as ``model_name`` when the model is created.
    name: str
    # Passed as the ``device`` entry of ``model_kwargs``.
    device: str = 'cpu'


# Default configuration: the model lives under
# ``<LOCAL_MODEL_PATH>/embedding/shibing624_text2vec-base-chinese``.
local_embedding_model = EmbeddingModelInfo(
    folder=settings.LOCAL_MODEL_PATH,
    name=os.path.join(
        settings.LOCAL_MODEL_PATH,
        'embedding',
        "shibing624_text2vec-base-chinese",
    ),
)

# Guards creation of entries in ``locks``.
_lock = threading.Lock()
# One lock per cache key; each guards creation of that key's model.
locks = {}

# Cache of already-built models, keyed by the caller-supplied key.
_embedding_model: dict[str, Optional[Embeddings]] = {}


class EmbeddingModelCache:
    """Static helpers that build embedding models once per key and reuse them."""

    @staticmethod
    def _new_instance(config: EmbeddingModelInfo = local_embedding_model):
        """Build a new ``HuggingFaceEmbeddings`` from ``config``.

        Embeddings are requested with ``normalize_embeddings=True``.
        """
        return HuggingFaceEmbeddings(
            model_name=config.name,
            cache_folder=config.folder,
            model_kwargs={'device': config.device},
            encode_kwargs={'normalize_embeddings': True},
        )

    @staticmethod
    def _get_lock(key: str = settings.DEFAULT_EMBEDDING_MODEL):
        """Return the lock associated with ``key``, creating it if absent."""
        key_lock = locks.get(key)
        if key_lock is not None:
            return key_lock

        # Not registered yet: take the global lock and look again so that
        # only one lock object is ever stored for a given key.
        with _lock:
            key_lock = locks.get(key)
            if key_lock is None:
                key_lock = threading.Lock()
                locks[key] = key_lock
        return key_lock

    @staticmethod
    def get_model(key: str = settings.DEFAULT_EMBEDDING_MODEL,
                  config: EmbeddingModelInfo = local_embedding_model) -> Embeddings:
        """Return the cached model for ``key``, building it on first use.

        ``config`` is only read when ``key`` has no cached model yet; a later
        call with the same ``key`` returns the already-cached instance
        regardless of the ``config`` passed.
        """
        cached = _embedding_model.get(key)
        if cached is not None:
            return cached

        # Cache miss: hold this key's lock and look again before building,
        # so the model for a key is constructed only once.
        with EmbeddingModelCache._get_lock(key):
            cached = _embedding_model.get(key)
            if cached is None:
                cached = EmbeddingModelCache._new_instance(config)
                _embedding_model[key] = cached
        return cached