import datetime
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional
from xml.dom.minidom import parseString

import dicttoxml
from sqlalchemy import and_, or_, select, func, delete, update, union
from sqlalchemy import create_engine, text
from sqlalchemy.orm import aliased
from sqlalchemy.orm import sessionmaker

from apps.ai_model.embedding import EmbeddingModelCache
from apps.template.generate_chart.generator import get_base_terminology_template
from apps.terminology.models.terminology_model import Terminology, TerminologyInfo
from common.core.config import settings
from common.core.deps import SessionDep, Trans

executor = ThreadPoolExecutor(max_workers=200)


def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
                     oid: Optional[int] = 1):
    _list: List[TerminologyInfo] = []

    child = aliased(Terminology)

    current_page = max(1, current_page)
    page_size = max(10, page_size)

    total_count = 0
    total_pages = 0

    if name and name.strip() != "":
        keyword_pattern = f"%{name.strip()}%"
        # Step 1: find the IDs of every node matching the keyword (parents and children alike)
        matched_ids_subquery = (
            select(Terminology.id)
            .where(and_(Terminology.word.ilike(keyword_pattern), Terminology.oid == oid))  # LIKE filter
            .subquery()
        )

        # Step 2: resolve those matches to parent nodes (a match that is itself a parent is kept as-is)
        parent_ids_subquery = (
            select(Terminology.id)
            .where(
                (Terminology.id.in_(matched_ids_subquery)) |
                (Terminology.id.in_(
                    select(Terminology.pid)
                    .where(Terminology.id.in_(matched_ids_subquery))
                    .where(Terminology.pid.isnot(None))
                ))
            )
            .where(Terminology.pid.is_(None))  # parent nodes only
        )

        count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
        total_count = session.execute(count_stmt).scalar()
        total_pages = (total_count + page_size - 1) // page_size

        if current_page > total_pages:
            current_page = 1

        # Step 3: page through the parent IDs
        paginated_parent_ids = (
            parent_ids_subquery
            .order_by(Terminology.create_time.desc())
            .offset((current_page - 1) * page_size)
            .limit(page_size)
            .subquery()
        )

        # Step 4: aggregate the child words (other_words) for those parents
        children_subquery = (
            select(
                child.pid,
                func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
            )
            .where(child.pid.isnot(None))
            .group_by(child.pid)
            .subquery()
        )

        # main query
        stmt = (
            select(
                Terminology.id,
                Terminology.word,
                Terminology.create_time,
                Terminology.description,
                children_subquery.c.other_words
            )
            .outerjoin(
                children_subquery,
                Terminology.id == children_subquery.c.pid
            )
            .where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
            .order_by(Terminology.create_time.desc())
        )
    else:
        parent_ids_subquery = (
            select(Terminology.id)
            .where(and_(Terminology.pid.is_(None), Terminology.oid == oid))  # parent nodes only
        )
        count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
        total_count = session.execute(count_stmt).scalar()
        total_pages = (total_count + page_size - 1) // page_size

        if current_page > total_pages:
            current_page = 1

        paginated_parent_ids = (
            parent_ids_subquery
            .order_by(Terminology.create_time.desc())
            .offset((current_page - 1) * page_size)
            .limit(page_size)
            .subquery()
        )

        stmt = (
            select(
                Terminology.id,
                Terminology.word,
                Terminology.create_time,
                Terminology.description,
                func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
            )
            .outerjoin(child, and_(Terminology.id == child.pid))
            .where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
            .group_by(Terminology.id, Terminology.word)
            .order_by(Terminology.create_time.desc())
        )

    result = session.execute(stmt)

    for row in result:
        _list.append(TerminologyInfo(
            id=row.id,
            word=row.word,
            create_time=row.create_time,
            description=row.description,
            other_words=row.other_words if row.other_words else [],
        ))

    return current_page, page_size, total_count, total_pages, _list

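
# Usage sketch (illustrative only): a route layer is assumed to unpack the returned
# tuple into a paging envelope. `some_session` and the response dict are hypothetical;
# only the tuple order (page, size, total_count, total_pages, items) comes from the
# function above.
#
#     page, size, total, pages, items = page_terminology(some_session, current_page=1,
#                                                        page_size=10, name='GDP', oid=1)
#     payload = {'current_page': page, 'page_size': size, 'total_count': total,
#                'total_pages': pages, 'data': items}
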

def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
    create_time = datetime.datetime.now()
    parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid)

    words = [info.word]
    for child in info.other_words:
        if child in words:
            raise Exception(trans("i18n_terminology.cannot_be_repeated"))
        else:
            words.append(child)

    exists = session.query(
        session.query(Terminology).filter(and_(Terminology.word.in_(words), Terminology.oid == oid)).exists()).scalar()
    if exists:
        raise Exception(trans("i18n_terminology.exists_in_db"))

    result = Terminology(**parent.model_dump())

    session.add(parent)
    session.flush()
    session.refresh(parent)

    result.id = parent.id
    session.commit()

    _list: List[Terminology] = []
    if info.other_words:
        for other_word in info.other_words:
            if other_word.strip() == "":
                continue
            _list.append(
                Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid))
        session.bulk_save_objects(_list)
        session.flush()
        session.commit()

    # embedding
    run_save_embeddings([result.id])

    return result.id

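
# Payload sketch (illustrative, not from the source): create_terminology only reads
# `word`, `description` and `other_words` from the incoming TerminologyInfo, so a call
# along these lines creates one parent row plus one child row per synonym. The example
# values are hypothetical.
#
#     info = TerminologyInfo(word='GDP', description='Gross domestic product',
#                            other_words=['国内生产总值'])
#     new_id = create_terminology(session, info, oid=1, trans=trans)
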

def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
    count = session.query(Terminology).filter(
        Terminology.oid == oid,
        Terminology.id == info.id
    ).count()
    if count == 0:
        raise Exception(trans('i18n_terminology.terminology_not_exists'))

    words = [info.word]
    for child in info.other_words:
        if child in words:
            raise Exception(trans("i18n_terminology.cannot_be_repeated"))
        else:
            words.append(child)

    exists = session.query(
        session.query(Terminology).filter(
            Terminology.word.in_(words),
            Terminology.oid == oid,
            or_(
                Terminology.pid != info.id,
                and_(Terminology.pid.is_(None), Terminology.id != info.id)
            ),
            Terminology.id != info.id
        ).exists()).scalar()
    if exists:
        raise Exception(trans("i18n_terminology.exists_in_db"))

    stmt = update(Terminology).where(and_(Terminology.id == info.id)).values(
        word=info.word,
        description=info.description,
    )
    session.execute(stmt)
    session.commit()

    stmt = delete(Terminology).where(and_(Terminology.pid == info.id))
    session.execute(stmt)
    session.commit()

    create_time = datetime.datetime.now()
    _list: List[Terminology] = []
    if info.other_words:
        for other_word in info.other_words:
            if other_word.strip() == "":
                continue
            _list.append(
                Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid))
        session.bulk_save_objects(_list)
        session.flush()
        session.commit()

    # embedding
    run_save_embeddings([info.id])

    return info.id


def delete_terminology(session: SessionDep, ids: list[int]):
    stmt = delete(Terminology).where(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids)))
    session.execute(stmt)
    session.commit()


def run_save_embeddings(ids: List[int]):
    executor.submit(save_embeddings, ids)


def fill_empty_embeddings():
    executor.submit(run_fill_empty_embeddings)


def run_fill_empty_embeddings():
    if not settings.EMBEDDING_ENABLED:
        return
    engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
    session_maker = sessionmaker(bind=engine)
    session = session_maker()
    stmt1 = select(Terminology.id).where(and_(Terminology.embedding.is_(None), Terminology.pid.is_(None)))
    stmt2 = select(Terminology.pid).where(and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct()
    combined_stmt = union(stmt1, stmt2)
    results = session.execute(combined_stmt).scalars().all()
    save_embeddings(results)


def save_embeddings(ids: List[int]):
    if not settings.EMBEDDING_ENABLED:
        return

    if not ids or len(ids) == 0:
        return
    try:
        engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
        session_maker = sessionmaker(bind=engine)
        session = session_maker()

        _list = session.query(Terminology).filter(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all()

        _words_list = [item.word for item in _list]

        model = EmbeddingModelCache.get_model()

        results = model.embed_documents(_words_list)

        for index in range(len(results)):
            item = results[index]
            stmt = update(Terminology).where(and_(Terminology.id == _list[index].id)).values(embedding=item)
            session.execute(stmt)
        session.commit()

    except Exception:
        traceback.print_exc()


embedding_sql = f"""
SELECT id, pid, word, similarity
FROM
    (SELECT id, pid, word, oid,
            ( 1 - (embedding <=> :embedding_array) ) AS similarity
     FROM terminology AS child
    ) TEMP
WHERE similarity > {settings.EMBEDDING_SIMILARITY} and oid = :oid
ORDER BY similarity DESC
LIMIT {settings.EMBEDDING_TOP_COUNT}
"""

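
# Note on the query above: `<=>` is pgvector's cosine-distance operator, so
# `1 - (embedding <=> :embedding_array)` is the cosine similarity, which is then
# thresholded by EMBEDDING_SIMILARITY and capped at EMBEDDING_TOP_COUNT rows.
# A minimal execution sketch, assuming a pgvector-enabled session and a query vector
# already produced by the embedding model (as done in select_terminology_by_word below):
#
#     rows = session.execute(text(embedding_sql),
#                            {'embedding_array': str(query_vector), 'oid': 1})
#     for r in rows:
#         print(r.id, r.pid, r.word, r.similarity)
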

def select_terminology_by_word(session: SessionDep, word: str, oid: int):
    if word.strip() == "":
        return []

    _list: List[Terminology] = []

    stmt = (
        select(
            Terminology.id,
            Terminology.pid,
            Terminology.word,
        )
        .where(
            and_(text(":sentence ILIKE '%' || word || '%'"), Terminology.oid == oid)
        )
    )

    results = session.execute(stmt, {'sentence': word}).fetchall()

    for row in results:
        _list.append(Terminology(id=row.id, word=row.word, pid=row.pid))

    if settings.EMBEDDING_ENABLED:
        try:
            model = EmbeddingModelCache.get_model()

            embedding = model.embed_query(word)

            results = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid})

            for row in results:
                _list.append(Terminology(id=row.id, word=row.word, pid=row.pid))

        except Exception:
            traceback.print_exc()

    _map: dict = {}
    _ids: list[int] = []
    for row in _list:
        if row.id in _ids or row.pid in _ids:
            continue
        if row.pid is not None:
            _ids.append(row.pid)
        else:
            _ids.append(row.id)

    if len(_ids) == 0:
        return []

    t_list = session.query(Terminology.id, Terminology.pid, Terminology.word, Terminology.description).filter(
        or_(Terminology.id.in_(_ids), Terminology.pid.in_(_ids))).all()
    for row in t_list:
        pid = str(row.pid) if row.pid is not None else str(row.id)
        if _map.get(pid) is None:
            _map[pid] = {'words': [], 'description': row.description}
        _map[pid]['words'].append(row.word)

    _results: list[dict] = []
    for key in _map.keys():
        _results.append(_map.get(key))

    return _results

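
# Result-shape sketch (illustrative): for a question that hits the "GDP" entry from
# get_example() below, select_terminology_by_word groups each matched parent together
# with its child synonyms and returns a list of plain dicts, e.g.
#
#     [{'words': ['GDP', '国内生产总值'],
#       'description': '指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。'}]
#
# which is the structure to_xml_string() expects.
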

def get_example():
    _obj = {
        'terminologies': [
            {'words': ['GDP', '国内生产总值'],
             'description': '指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。'},
        ]
    }
    return to_xml_string(_obj, 'example')


def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
    item_name_func = lambda x: 'terminology' if x == 'terminologies' else 'word' if x == 'words' else 'item'
    dicttoxml.LOG.setLevel(logging.ERROR)
    xml = dicttoxml.dicttoxml(_dict,
                              cdata=['word', 'description'],
                              custom_root=root,
                              item_func=item_name_func,
                              xml_declaration=False,
                              encoding='utf-8',
                              attr_type=False).decode('utf-8')
    pretty_xml = parseString(xml).toprettyxml()

    if pretty_xml.startswith('<?xml'):
        end_index = pretty_xml.find('>') + 1
        pretty_xml = pretty_xml[end_index:].lstrip()

    # Replace all XML escape entities with their literal characters
    escape_map = {
        '&lt;': '<',
        '&gt;': '>',
        '&amp;': '&',
        '&quot;': '"',
        '&apos;': "'"
    }
    for escaped, original in escape_map.items():
        pretty_xml = pretty_xml.replace(escaped, original)

    return pretty_xml

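
# Output-shape sketch (approximate, not captured from a real run): for the dict built in
# get_example(), dicttoxml with custom_root='example', the item_func mapping and
# attr_type=False should yield, after pretty-printing and entity replacement, roughly:
#
#     <example>
#       <terminologies>
#         <terminology>
#           <words>
#             <word><![CDATA[GDP]]></word>
#             <word><![CDATA[国内生产总值]]></word>
#           </words>
#           <description><![CDATA[...]]></description>
#         </terminology>
#       </terminologies>
#     </example>
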

def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1) -> str:
    if not oid:
        oid = 1
    _results = select_terminology_by_word(session, question, oid)
    if _results and len(_results) > 0:
        terminology = to_xml_string(_results)
        template = get_base_terminology_template().format(terminologies=terminology)
        return template
    else:
        return ''
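
# End-to-end sketch (illustrative): a caller that wants terminology context injected into
# a prompt might do something like the following; the surrounding chat pipeline and
# `user_question` are hypothetical, only the call itself is from this module.
#
#     terminology_block = get_terminology_template(session, user_question, oid=1)
#     if terminology_block:
#         system_prompt += terminology_block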