Files
SQLBot/backend/apps/terminology/curd/terminology.py

431 lines
14 KiB
Python
Raw Normal View History

2025-09-08 16:36:17 +08:00
import datetime
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional
from xml.dom.minidom import parseString
import dicttoxml
from sqlalchemy import and_, or_, select, func, delete, update, union
from sqlalchemy import create_engine, text
from sqlalchemy.orm import aliased
from sqlalchemy.orm import sessionmaker
from apps.ai_model.embedding import EmbeddingModelCache
from apps.template.generate_chart.generator import get_base_terminology_template
from apps.terminology.models.terminology_model import Terminology, TerminologyInfo
from common.core.config import settings
from common.core.deps import SessionDep, Trans
# Module-wide thread pool; run_save_embeddings and fill_empty_embeddings submit
# their embedding jobs to it so callers are not blocked while vectors are computed.
executor = ThreadPoolExecutor(max_workers=200)
def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
                     oid: Optional[int] = 1):
    """Return one page of top-level terminology entries with their synonyms.

    Only rows whose ``pid`` is NULL (parents) are paged; the words of their
    child rows are aggregated into ``other_words``.

    Args:
        session: Database session used to run the queries.
        current_page: 1-based page number; values below 1 are raised to 1 and
            values past the last page fall back to 1.
        page_size: Rows per page; values below 10 are raised to 10.
        name: Optional keyword. When non-blank, a parent is listed if its own
            word or the word of one of its children matches (case-insensitive
            substring match).
        oid: Value matched against ``Terminology.oid``.

    Returns:
        A tuple ``(current_page, page_size, total_count, total_pages, items)``
        where ``items`` is a list of ``TerminologyInfo``.
    """
    synonym = aliased(Terminology)
    current_page = max(1, current_page)
    page_size = max(10, page_size)
    keyword = name.strip() if name else ""

    # Step 1: build the query that yields the ids of every parent row to list.
    if keyword:
        keyword_pattern = f"%{keyword}%"
        # Every row (parent or child) whose word matches the keyword.
        matched_ids_subquery = (
            select(Terminology.id)
            .where(and_(Terminology.word.ilike(keyword_pattern), Terminology.oid == oid))
            .subquery()
        )
        # Parents referenced by the matching child rows.
        parents_of_matches = (
            select(Terminology.pid)
            .where(Terminology.id.in_(matched_ids_subquery))
            .where(Terminology.pid.isnot(None))
        )
        parent_ids_query = (
            select(Terminology.id)
            .where(or_(
                Terminology.id.in_(matched_ids_subquery),
                Terminology.id.in_(parents_of_matches),
            ))
            .where(Terminology.pid.is_(None))  # keep parent rows only
        )
    else:
        parent_ids_query = (
            select(Terminology.id)
            .where(and_(Terminology.pid.is_(None), Terminology.oid == oid))  # keep parent rows only
        )

    # Step 2: count the parents and normalise the requested page number.
    count_stmt = select(func.count()).select_from(parent_ids_query.subquery())
    total_count = session.execute(count_stmt).scalar()
    total_pages = (total_count + page_size - 1) // page_size
    if current_page > total_pages:
        current_page = 1

    # Step 3: ids of the parents that belong to the requested page.
    paginated_parent_ids = (
        parent_ids_query
        .order_by(Terminology.create_time.desc())
        .offset((current_page - 1) * page_size)
        .limit(page_size)
        .subquery()
    )
    on_requested_page = and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid)

    # Step 4: load the page rows together with the aggregated child words.
    if keyword:
        children_subquery = (
            select(
                synonym.pid,
                func.jsonb_agg(synonym.word).filter(synonym.word.isnot(None)).label('other_words'),
            )
            .where(synonym.pid.isnot(None))
            .group_by(synonym.pid)
            .subquery()
        )
        stmt = (
            select(
                Terminology.id,
                Terminology.word,
                Terminology.create_time,
                Terminology.description,
                children_subquery.c.other_words,
            )
            .outerjoin(children_subquery, Terminology.id == children_subquery.c.pid)
            .where(on_requested_page)
            .order_by(Terminology.create_time.desc())
        )
    else:
        stmt = (
            select(
                Terminology.id,
                Terminology.word,
                Terminology.create_time,
                Terminology.description,
                func.jsonb_agg(synonym.word).filter(synonym.word.isnot(None)).label('other_words'),
            )
            .outerjoin(synonym, Terminology.id == synonym.pid)
            .where(on_requested_page)
            .group_by(Terminology.id, Terminology.word)
            .order_by(Terminology.create_time.desc())
        )

    items: List[TerminologyInfo] = [
        TerminologyInfo(
            id=row.id,
            word=row.word,
            create_time=row.create_time,
            description=row.description,
            other_words=row.other_words or [],
        )
        for row in session.execute(stmt)
    ]
    return current_page, page_size, total_count, total_pages, items
def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
    """Create a terminology entry and its synonym rows, then queue its embedding.

    Args:
        session: Database session used for the inserts.
        info: Entry to create; ``word`` is the main term, ``other_words`` the
            synonyms (may be empty or None) and ``description`` its explanation.
        oid: Value stored in ``Terminology.oid`` for every created row.
        trans: Translator used to build the error messages.

    Returns:
        The id of the newly created parent row.

    Raises:
        Exception: If the submitted words contain duplicates, or if any of
            them already exists for this ``oid``.
    """
    # other_words may be None; treat that the same as "no synonyms".
    other_words = info.other_words or []

    words = [info.word]
    for candidate in other_words:
        if candidate in words:
            raise Exception(trans("i18n_terminology.cannot_be_repeated"))
        words.append(candidate)

    exists = session.query(
        session.query(Terminology).filter(and_(Terminology.word.in_(words), Terminology.oid == oid)).exists()
    ).scalar()
    if exists:
        raise Exception(trans("i18n_terminology.exists_in_db"))

    create_time = datetime.datetime.now()
    parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid)
    session.add(parent)
    # Flush (without committing) so the database assigns the parent id that the
    # synonym rows reference through pid.
    session.flush()
    parent_id = parent.id

    # Blank synonyms are ignored.
    synonyms = [
        Terminology(pid=parent_id, word=other_word, create_time=create_time, oid=oid)
        for other_word in other_words
        if other_word.strip() != ""
    ]
    if synonyms:
        session.bulk_save_objects(synonyms)

    # Single commit: the parent and its synonyms are stored together or not at all.
    session.commit()

    # The embedding job uses its own session, so it must be queued after the commit.
    run_save_embeddings([parent_id])
    return parent_id
def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
    """Update a terminology entry, replace its synonym rows and queue its embedding.

    Args:
        session: Database session used for the changes.
        info: New state of the entry; ``id`` selects the row to update,
            ``other_words`` (may be empty or None) replaces the synonyms.
        oid: Value the entry's ``oid`` must match; also stored on new synonyms.
        trans: Translator used to build the error messages.

    Returns:
        ``info.id``.

    Raises:
        Exception: If no row with ``info.id`` exists for this ``oid``, if the
            submitted words contain duplicates, or if any of them is already
            used by a different entry of this ``oid``.
    """
    count = session.query(Terminology).filter(
        Terminology.oid == oid,
        Terminology.id == info.id
    ).count()
    if count == 0:
        raise Exception(trans('i18n_terminology.terminology_not_exists'))

    # other_words may be None; treat that the same as "no synonyms".
    other_words = info.other_words or []

    words = [info.word]
    for candidate in other_words:
        if candidate in words:
            raise Exception(trans("i18n_terminology.cannot_be_repeated"))
        words.append(candidate)

    # Look for the same words on rows that belong to a different entry: the
    # entry's own row and its current synonyms are excluded from the check.
    exists = session.query(
        session.query(Terminology).filter(
            Terminology.word.in_(words),
            Terminology.oid == oid,
            or_(
                Terminology.pid != info.id,
                and_(Terminology.pid.is_(None), Terminology.id != info.id)
            ),
            Terminology.id != info.id
        ).exists()).scalar()
    if exists:
        raise Exception(trans("i18n_terminology.exists_in_db"))

    session.execute(
        update(Terminology).where(Terminology.id == info.id).values(
            word=info.word,
            description=info.description,
        )
    )
    # Synonyms are replaced wholesale: drop the old rows, then insert the new ones.
    session.execute(delete(Terminology).where(Terminology.pid == info.id))

    create_time = datetime.datetime.now()
    synonyms = [
        Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid)
        for other_word in other_words
        if other_word.strip() != ""  # blank synonyms are ignored
    ]
    if synonyms:
        session.bulk_save_objects(synonyms)

    # Single commit: a failure cannot leave the entry with its synonyms deleted
    # but not re-created.
    session.commit()

    # The embedding job uses its own session, so it must be queued after the commit.
    run_save_embeddings([info.id])
    return info.id
def delete_terminology(session: SessionDep, ids: list[int]):
    """Delete the rows with the given ids plus every row whose pid is one of them, then commit."""
    targeted_rows = or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))
    session.execute(delete(Terminology).where(targeted_rows))
    session.commit()
def run_save_embeddings(ids: List[int]):
    """Submit ``save_embeddings(ids)`` to the module thread pool and return immediately."""
    executor.submit(save_embeddings, ids)
def fill_empty_embeddings():
    """Submit ``run_fill_empty_embeddings`` to the module thread pool and return immediately."""
    executor.submit(run_fill_empty_embeddings)
def run_fill_empty_embeddings():
    """Compute embeddings for every terminology entry that still lacks one.

    Does nothing when ``settings.EMBEDDING_ENABLED`` is false. Otherwise it
    collects the ids of parent rows without an embedding, plus the parent ids
    of child rows without an embedding, and hands them to ``save_embeddings``.
    """
    if not settings.EMBEDDING_ENABLED:
        return

    # A dedicated engine/session is used here rather than a request session.
    # Both are released explicitly so each run does not leak connections.
    engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
    try:
        with sessionmaker(bind=engine)() as session:
            parents_missing = select(Terminology.id).where(
                and_(Terminology.embedding.is_(None), Terminology.pid.is_(None)))
            parents_of_children_missing = select(Terminology.pid).where(
                and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct()
            combined_stmt = union(parents_missing, parents_of_children_missing)
            ids = list(session.execute(combined_stmt).scalars().all())
    finally:
        engine.dispose()

    save_embeddings(ids)
def save_embeddings(ids: List[int]):
    """Compute and store embeddings for the given entries and their synonyms.

    Every row whose ``id`` or ``pid`` is in ``ids`` gets its ``word`` embedded
    and the vector written to its ``embedding`` column. Does nothing when
    ``settings.EMBEDDING_ENABLED`` is false or ``ids`` is empty. Errors are
    printed with ``traceback.print_exc()`` and not re-raised (best effort).

    Args:
        ids: Ids of the parent terminology rows to (re)embed.
    """
    if not settings.EMBEDDING_ENABLED:
        return
    if not ids:
        return

    engine = None
    try:
        # A dedicated engine/session is used here rather than a request session.
        # Both are released below so repeated calls do not leak connections.
        engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
        with sessionmaker(bind=engine)() as session:
            rows = session.query(Terminology).filter(
                or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all()
            if not rows:
                # Nothing matched, so there is nothing to send to the model.
                return

            model = EmbeddingModelCache.get_model()
            vectors = model.embed_documents([row.word for row in rows])
            # Vectors are paired with rows by position.
            for row, vector in zip(rows, vectors):
                stmt = update(Terminology).where(Terminology.id == row.id).values(embedding=vector)
                session.execute(stmt)
            session.commit()
    except Exception:
        traceback.print_exc()
    finally:
        if engine is not None:
            engine.dispose()
# Raw SQL executed by select_terminology_by_word to find terminology rows by
# embedding similarity. The similarity threshold and the LIMIT are interpolated
# from settings once, when this module is imported; :embedding_array and :oid
# are bound parameters supplied at execution time.
# NOTE(review): `<=>` is presumably pgvector's cosine-distance operator, making
# `1 - distance` a cosine similarity — confirm against the column type.
embedding_sql = f"""
SELECT id, pid, word, similarity
FROM
(SELECT id, pid, word, oid,
( 1 - (embedding <=> :embedding_array) ) AS similarity
FROM terminology AS child
) TEMP
WHERE similarity > {settings.EMBEDDING_SIMILARITY} and oid = :oid
ORDER BY similarity DESC
LIMIT {settings.EMBEDDING_TOP_COUNT}
"""
def select_terminology_by_word(session: SessionDep, word: str, oid: int):
    """Find the terminology entries relevant to a piece of text.

    Rows are matched in two ways: their ``word`` occurs (case-insensitively)
    inside ``word``, and, when ``settings.EMBEDDING_ENABLED`` is true, by the
    similarity query in ``embedding_sql``. Each hit is resolved to its parent
    entry, and every entry is returned with all of its words.

    Args:
        session: Database session used to run the queries.
        word: Text to search in (for example a user question).
        oid: Value matched against ``Terminology.oid``.

    Returns:
        A list of dicts ``{'words': [...], 'description': ...}``, one per
        matched entry; an empty list when ``word`` is blank/None or nothing
        matches.
    """
    if not word or word.strip() == "":
        return []

    hits: List[Terminology] = []

    # Substring match: keep rows whose word is contained in the given text.
    stmt = (
        select(
            Terminology.id,
            Terminology.pid,
            Terminology.word,
        )
        .where(
            and_(text(":sentence ILIKE '%' || word || '%'"), Terminology.oid == oid)
        )
    )
    for row in session.execute(stmt, {'sentence': word}).fetchall():
        hits.append(Terminology(id=row.id, word=row.word, pid=row.pid))

    # Similarity match is best effort: a failure is printed and the substring
    # hits are still used.
    if settings.EMBEDDING_ENABLED:
        try:
            model = EmbeddingModelCache.get_model()
            embedding = model.embed_query(word)
            similar_rows = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid})
            for row in similar_rows:
                hits.append(Terminology(id=row.id, word=row.word, pid=row.pid))
        except Exception:
            traceback.print_exc()

    # Resolve every hit to the id of its entry (its pid, or its own id for a
    # parent row), skipping hits whose entry has already been collected.
    entry_ids: list[int] = []
    seen_ids: set = set()
    for hit in hits:
        if hit.id in seen_ids or hit.pid in seen_ids:
            continue
        entry_id = hit.pid if hit.pid is not None else hit.id
        entry_ids.append(entry_id)
        seen_ids.add(entry_id)

    if not entry_ids:
        return []

    rows = session.query(Terminology.id, Terminology.pid, Terminology.word, Terminology.description).filter(
        or_(Terminology.id.in_(entry_ids), Terminology.pid.in_(entry_ids))).all()

    grouped: dict = {}
    for row in rows:
        key = str(row.pid) if row.pid is not None else str(row.id)
        entry = grouped.get(key)
        if entry is None:
            entry = {'words': [], 'description': row.description}
            grouped[key] = entry
        elif row.pid is None:
            # The query has no ORDER BY, so a synonym row may be seen before its
            # parent; the parent row's description always wins.
            entry['description'] = row.description
        entry['words'].append(row.word)

    return list(grouped.values())
def get_example():
    """Return a fixed sample terminology entry rendered by ``to_xml_string`` with root ``example``."""
    sample_entry = {
        'words': ['GDP', '国内生产总值'],
        'description': '指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。',
    }
    return to_xml_string({'terminologies': [sample_entry]}, 'example')
def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
    """Render terminology data as pretty-printed XML text.

    The data is converted with ``dicttoxml``, pretty-printed with ``minidom``,
    stripped of the leading XML declaration, and finally has its XML entity
    escapes replaced by the literal characters.

    Args:
        _dict: Data to serialise (a dict, or a list of dicts).
        root: Name of the root element.

    Returns:
        The XML text without an ``<?xml ...?>`` declaration.
    """

    def item_name_func(parent_name):
        # List items are named after the list that contains them.
        if parent_name == 'terminologies':
            return 'terminology'
        if parent_name == 'words':
            return 'word'
        return 'item'

    dicttoxml.LOG.setLevel(logging.ERROR)
    xml = dicttoxml.dicttoxml(_dict,
                              cdata=['word', 'description'],
                              custom_root=root,
                              item_func=item_name_func,
                              xml_declaration=False,
                              encoding='utf-8',
                              attr_type=False).decode('utf-8')
    pretty_xml = parseString(xml).toprettyxml()

    # toprettyxml() re-adds the declaration; drop it.
    if pretty_xml.startswith('<?xml'):
        end_index = pretty_xml.find('>') + 1
        pretty_xml = pretty_xml[end_index:].lstrip()

    # Replace the XML entity escapes with the literal characters. '&amp;' must be
    # handled last: replacing it earlier would turn '&amp;quot;' into '&quot;'
    # and a later pass would then wrongly unescape that a second time.
    escape_pairs = (
        ('&lt;', '<'),
        ('&gt;', '>'),
        ('&quot;', '"'),
        ('&apos;', "'"),
        ('&amp;', '&'),
    )
    for escaped, original in escape_pairs:
        pretty_xml = pretty_xml.replace(escaped, original)
    return pretty_xml
def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1) -> str:
    """Build the terminology prompt section for a question.

    Args:
        session: Database session used for the lookup.
        question: Text passed to ``select_terminology_by_word``.
        oid: Value matched against ``Terminology.oid``; a falsy value is
            replaced by 1.

    Returns:
        The base terminology template filled with the matched entries as XML,
        or an empty string when nothing matches.
    """
    effective_oid = oid or 1
    matches = select_terminology_by_word(session, question, effective_oid)
    if not matches:
        return ''
    terminology_xml = to_xml_string(matches)
    return get_base_terminology_template().format(terminologies=terminology_xml)