This commit is contained in:
2025-09-08 16:35:59 +08:00
parent e16b7a873c
commit 3f14677c82

View File

@@ -0,0 +1,725 @@
import datetime
from typing import List
import orjson
import sqlparse
from sqlalchemy import and_, select, update
from sqlalchemy.orm import aliased
from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat, ChatQuestion, ChatLog, \
TypeEnum, OperationEnum, ChatRecordResult
from apps.datasource.models.datasource import CoreDatasource
from apps.system.crud.assistant import AssistantOutDsFactory
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser
from common.utils.utils import extract_nested_json
def get_chat_record_by_id(session: SessionDep, record_id: int):
record: ChatRecord | None = None
stmt = select(ChatRecord.id, ChatRecord.question, ChatRecord.chat_id, ChatRecord.datasource, ChatRecord.engine_type,
ChatRecord.ai_modal_id, ChatRecord.create_by).where(
and_(ChatRecord.id == record_id))
result = session.execute(stmt)
for r in result:
record = ChatRecord(id=r.id, question=r.question, chat_id=r.chat_id, datasource=r.datasource,
engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by)
return record
def list_chats(session: SessionDep, current_user: CurrentUser) -> List[Chat]:
oid = current_user.oid if current_user.oid is not None else 1
chart_list = session.query(Chat).filter(and_(Chat.create_by == current_user.id, Chat.oid == oid)).order_by(
Chat.create_time.desc()).all()
return chart_list
def rename_chat(session: SessionDep, rename_object: RenameChat) -> str:
chat = session.get(Chat, rename_object.id)
if not chat:
raise Exception(f"Chat with id {rename_object.id} not found")
chat.brief = rename_object.brief.strip()[:20]
session.add(chat)
session.flush()
session.refresh(chat)
brief = chat.brief
session.commit()
return brief
def delete_chat(session, chart_id) -> str:
chat = session.query(Chat).filter(Chat.id == chart_id).first()
if not chat:
return f'Chat with id {chart_id} has been deleted'
session.delete(chat)
session.commit()
return f'Chat with id {chart_id} has been deleted'
def get_chart_config(session: SessionDep, chart_record_id: int):
stmt = select(ChatRecord.chart).where(and_(ChatRecord.id == chart_record_id))
res = session.execute(stmt)
for row in res:
try:
return orjson.loads(row.chart)
except Exception:
pass
return {}
def get_last_execute_sql_error(session: SessionDep, chart_id: int):
stmt = select(ChatRecord.error).where(and_(ChatRecord.chat_id == chart_id)).order_by(
ChatRecord.create_time.desc()).limit(1)
res = session.execute(stmt).scalar()
if res:
try:
obj = orjson.loads(res)
if obj.get('type') and obj.get('type') == 'exec-sql-err':
return obj.get('traceback')
except Exception:
pass
return None
def get_chat_chart_data(session: SessionDep, chart_record_id: int):
stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chart_record_id))
res = session.execute(stmt)
for row in res:
try:
return orjson.loads(row.data)
except Exception:
pass
return []
def get_chat_predict_data(session: SessionDep, chart_record_id: int):
stmt = select(ChatRecord.predict_data).where(and_(ChatRecord.id == chart_record_id))
res = session.execute(stmt)
for row in res:
try:
return orjson.loads(row.predict_data)
except Exception:
pass
return []
def get_chat_with_records_with_data(session: SessionDep, chart_id: int, current_user: CurrentUser,
current_assistant: CurrentAssistant) -> ChatInfo:
return get_chat_with_records(session, chart_id, current_user, current_assistant, True)
dynamic_ds_types = [1, 3]
def get_chat_with_records(session: SessionDep, chart_id: int, current_user: CurrentUser,
current_assistant: CurrentAssistant, with_data: bool = False) -> ChatInfo:
chat = session.get(Chat, chart_id)
if not chat:
raise Exception(f"Chat with id {chart_id} not found")
chat_info = ChatInfo(**chat.model_dump())
if current_assistant and current_assistant.type in dynamic_ds_types:
out_ds_instance = AssistantOutDsFactory.get_instance(current_assistant)
ds = out_ds_instance.get_ds(chat.datasource)
else:
ds = session.get(CoreDatasource, chat.datasource) if chat.datasource else None
if not ds:
chat_info.datasource_exists = False
chat_info.datasource_name = 'Datasource not exist'
else:
chat_info.datasource_exists = True
chat_info.datasource_name = ds.name
chat_info.ds_type = ds.type
sql_alias_log = aliased(ChatLog)
chart_alias_log = aliased(ChatLog)
analysis_alias_log = aliased(ChatLog)
predict_alias_log = aliased(ChatLog)
stmt = (select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.create_time, ChatRecord.finish_time,
ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql,
ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict,
ChatRecord.datasource_select_answer, ChatRecord.analysis_record_id, ChatRecord.predict_record_id,
ChatRecord.recommended_question, ChatRecord.first_chat,
ChatRecord.finish, ChatRecord.error,
sql_alias_log.reasoning_content.label('sql_reasoning_content'),
chart_alias_log.reasoning_content.label('chart_reasoning_content'),
analysis_alias_log.reasoning_content.label('analysis_reasoning_content'),
predict_alias_log.reasoning_content.label('predict_reasoning_content')
)
.outerjoin(sql_alias_log, and_(sql_alias_log.pid == ChatRecord.id,
sql_alias_log.type == TypeEnum.CHAT,
sql_alias_log.operate == OperationEnum.GENERATE_SQL))
.outerjoin(chart_alias_log, and_(chart_alias_log.pid == ChatRecord.id,
chart_alias_log.type == TypeEnum.CHAT,
chart_alias_log.operate == OperationEnum.GENERATE_CHART))
.outerjoin(analysis_alias_log, and_(analysis_alias_log.pid == ChatRecord.id,
analysis_alias_log.type == TypeEnum.CHAT,
analysis_alias_log.operate == OperationEnum.ANALYSIS))
.outerjoin(predict_alias_log, and_(predict_alias_log.pid == ChatRecord.id,
predict_alias_log.type == TypeEnum.CHAT,
predict_alias_log.operate == OperationEnum.PREDICT_DATA))
.where(and_(ChatRecord.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by(
ChatRecord.create_time))
if with_data:
stmt = select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.create_time, ChatRecord.finish_time,
ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql,
ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict,
ChatRecord.datasource_select_answer, ChatRecord.analysis_record_id, ChatRecord.predict_record_id,
ChatRecord.recommended_question, ChatRecord.first_chat,
ChatRecord.finish, ChatRecord.error, ChatRecord.data, ChatRecord.predict_data).where(
and_(ChatRecord.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by(
ChatRecord.create_time)
result = session.execute(stmt).all()
record_list: list[ChatRecordResult] = []
for row in result:
if not with_data:
record_list.append(
ChatRecordResult(id=row.id, chat_id=row.chat_id, create_time=row.create_time,
finish_time=row.finish_time,
question=row.question, sql_answer=row.sql_answer, sql=row.sql,
chart_answer=row.chart_answer, chart=row.chart,
analysis=row.analysis, predict=row.predict,
datasource_select_answer=row.datasource_select_answer,
analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id,
recommended_question=row.recommended_question, first_chat=row.first_chat,
finish=row.finish, error=row.error,
sql_reasoning_content=row.sql_reasoning_content,
chart_reasoning_content=row.chart_reasoning_content,
analysis_reasoning_content=row.analysis_reasoning_content,
predict_reasoning_content=row.predict_reasoning_content,
))
else:
record_list.append(
ChatRecordResult(id=row.id, chat_id=row.chat_id, create_time=row.create_time,
finish_time=row.finish_time,
question=row.question, sql_answer=row.sql_answer, sql=row.sql,
chart_answer=row.chart_answer, chart=row.chart,
analysis=row.analysis, predict=row.predict,
datasource_select_answer=row.datasource_select_answer,
analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id,
recommended_question=row.recommended_question, first_chat=row.first_chat,
finish=row.finish, error=row.error, data=row.data, predict_data=row.predict_data))
result = list(map(format_record, record_list))
chat_info.records = result
return chat_info
def format_record(record: ChatRecordResult):
_dict = record.model_dump()
if record.sql_answer and record.sql_answer.strip() != '' and record.sql_answer.strip()[0] == '{' and \
record.sql_answer.strip()[-1] == '}':
_obj = orjson.loads(record.sql_answer)
_dict['sql_answer'] = _obj.get('reasoning_content')
if record.sql_reasoning_content and record.sql_reasoning_content.strip() != '':
_dict['sql_answer'] = record.sql_reasoning_content
if record.chart_answer and record.chart_answer.strip() != '' and record.chart_answer.strip()[0] == '{' and \
record.chart_answer.strip()[-1] == '}':
_obj = orjson.loads(record.chart_answer)
_dict['chart_answer'] = _obj.get('reasoning_content')
if record.chart_reasoning_content and record.chart_reasoning_content.strip() != '':
_dict['chart_answer'] = record.chart_reasoning_content
if record.analysis and record.analysis.strip() != '' and record.analysis.strip()[0] == '{' and \
record.analysis.strip()[-1] == '}':
_obj = orjson.loads(record.analysis)
_dict['analysis_thinking'] = _obj.get('reasoning_content')
_dict['analysis'] = _obj.get('content')
if record.analysis_reasoning_content and record.analysis_reasoning_content.strip() != '':
_dict['analysis_thinking'] = record.analysis_reasoning_content
if record.predict and record.predict.strip() != '' and record.predict.strip()[0] == '{' and record.predict.strip()[
-1] == '}':
_obj = orjson.loads(record.predict)
_dict['predict'] = _obj.get('reasoning_content')
_dict['predict_content'] = _obj.get('content')
if record.predict_reasoning_content and record.predict_reasoning_content.strip() != '':
_dict['predict'] = record.predict_reasoning_content
if record.data and record.data.strip() != '':
try:
_obj = orjson.loads(record.data)
_dict['data'] = _obj
except Exception:
pass
if record.predict_data and record.predict_data.strip() != '':
try:
_obj = orjson.loads(record.predict_data)
_dict['predict_data'] = _obj
except Exception:
pass
if record.sql and record.sql.strip() != '':
try:
_dict['sql'] = sqlparse.format(record.sql, reindent=True)
except Exception:
pass
return _dict
def list_generate_sql_logs(session: SessionDep, chart_id: int) -> List[ChatLog]:
stmt = select(ChatLog).where(
and_(ChatLog.pid.in_(select(ChatRecord.id).where(and_(ChatRecord.chat_id == chart_id))),
ChatLog.type == TypeEnum.CHAT, ChatLog.operate == OperationEnum.GENERATE_SQL)).order_by(
ChatLog.start_time)
result = session.execute(stmt).all()
_list = []
for row in result:
for r in row:
_list.append(ChatLog(**r.model_dump()))
return _list
def list_generate_chart_logs(session: SessionDep, chart_id: int) -> List[ChatLog]:
stmt = select(ChatLog).where(
and_(ChatLog.pid.in_(select(ChatRecord.id).where(and_(ChatRecord.chat_id == chart_id))),
ChatLog.type == TypeEnum.CHAT, ChatLog.operate == OperationEnum.GENERATE_CHART)).order_by(
ChatLog.start_time)
result = session.execute(stmt).all()
_list = []
for row in result:
for r in row:
_list.append(ChatLog(**r.model_dump()))
return _list
def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat,
require_datasource: bool = True) -> ChatInfo:
if not create_chat_obj.datasource and require_datasource:
raise Exception("Datasource cannot be None")
if not create_chat_obj.question or create_chat_obj.question.strip() == '':
create_chat_obj.question = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat = Chat(create_time=datetime.datetime.now(),
create_by=current_user.id,
oid=current_user.oid if current_user.oid is not None else 1,
brief=create_chat_obj.question.strip()[:20],
origin=create_chat_obj.origin if create_chat_obj.origin is not None else 0)
ds: CoreDatasource | None = None
if create_chat_obj.datasource:
chat.datasource = create_chat_obj.datasource
ds = session.get(CoreDatasource, create_chat_obj.datasource)
if not ds:
raise Exception(f"Datasource with id {create_chat_obj.datasource} not found")
chat.engine_type = ds.type_name
else:
chat.engine_type = ''
chat_info = ChatInfo(**chat.model_dump())
session.add(chat)
session.flush()
session.refresh(chat)
chat_info.id = chat.id
session.commit()
if ds:
chat_info.datasource_exists = True
chat_info.datasource_name = ds.name
chat_info.ds_type = ds.type
if require_datasource and ds:
# generate first empty record
record = ChatRecord()
record.chat_id = chat.id
record.datasource = ds.id
record.engine_type = ds.type_name
record.first_chat = True
record.finish = True
record.create_time = datetime.datetime.now()
record.create_by = current_user.id
_record = ChatRecord(**record.model_dump())
session.add(record)
session.flush()
session.refresh(record)
_record.id = record.id
session.commit()
chat_info.records.append(_record)
return chat_info
def save_question(session: SessionDep, current_user: CurrentUser, question: ChatQuestion) -> ChatRecord:
if not question.chat_id:
raise Exception("ChatId cannot be None")
if not question.question or question.question.strip() == '':
raise Exception("Question cannot be Empty")
# chat = session.query(Chat).filter(Chat.id == question.chat_id).first()
chat: Chat = session.get(Chat, question.chat_id)
if not chat:
raise Exception(f"Chat with id {question.chat_id} not found")
record = ChatRecord()
record.question = question.question
record.chat_id = chat.id
record.create_time = datetime.datetime.now()
record.create_by = current_user.id
record.datasource = chat.datasource
record.engine_type = chat.engine_type
record.ai_modal_id = question.ai_modal_id
result = ChatRecord(**record.model_dump())
session.add(record)
session.flush()
session.refresh(record)
result.id = record.id
session.commit()
return result
def save_analysis_predict_record(session: SessionDep, base_record: ChatRecord, action_type: str) -> ChatRecord:
record = ChatRecord()
record.question = base_record.question
record.chat_id = base_record.chat_id
record.datasource = base_record.datasource
record.engine_type = base_record.engine_type
record.ai_modal_id = base_record.ai_modal_id
record.create_time = datetime.datetime.now()
record.create_by = base_record.create_by
record.chart = base_record.chart
record.data = base_record.data
if action_type == 'analysis':
record.analysis_record_id = base_record.id
elif action_type == 'predict':
record.predict_record_id = base_record.id
result = ChatRecord(**record.model_dump())
session.add(record)
session.flush()
session.refresh(record)
result.id = record.id
session.commit()
return result
def start_log(session: SessionDep, ai_modal_id: int, ai_modal_name: str, operate: OperationEnum, record_id: int,
full_message: list[dict]) -> ChatLog:
log = ChatLog(type=TypeEnum.CHAT, operate=operate, pid=record_id, ai_modal_id=ai_modal_id, base_modal=ai_modal_name,
messages=full_message, start_time=datetime.datetime.now())
result = ChatLog(**log.model_dump())
session.add(log)
session.flush()
session.refresh(log)
result.id = log.id
session.commit()
return result
def end_log(session: SessionDep, log: ChatLog, full_message: list[dict], reasoning_content: str = None,
token_usage=None) -> ChatLog:
if token_usage is None:
token_usage = {}
log.messages = full_message
log.token_usage = token_usage
log.finish_time = datetime.datetime.now()
log.reasoning_content = reasoning_content if reasoning_content and len(reasoning_content.strip()) > 0 else None
stmt = update(ChatLog).where(and_(ChatLog.id == log.id)).values(
messages=log.messages,
token_usage=log.token_usage,
finish_time=log.finish_time,
reasoning_content=log.reasoning_content
)
session.execute(stmt)
session.commit()
return log
def save_sql_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values(
sql_answer=answer,
)
session.execute(stmt)
session.commit()
record = get_chat_record_by_id(session, record_id)
return record
def save_analysis_answer(session: SessionDep, record_id: int, answer: str = '') -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values(
analysis=answer,
)
session.execute(stmt)
session.commit()
record = get_chat_record_by_id(session, record_id)
return record
def save_predict_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values(
predict=answer,
)
session.execute(stmt)
session.commit()
record = get_chat_record_by_id(session, record_id)
return record
def save_select_datasource_answer(session: SessionDep, record_id: int, answer: str,
datasource: int = None, engine_type: str = None) -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
record = get_chat_record_by_id(session, record_id)
record.datasource_select_answer = answer
if datasource:
record.datasource = datasource
record.engine_type = engine_type
result = ChatRecord(**record.model_dump())
if datasource:
stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
datasource_select_answer=record.datasource_select_answer,
datasource=record.datasource,
engine_type=record.engine_type,
)
else:
stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
datasource_select_answer=record.datasource_select_answer,
)
session.execute(stmt)
session.commit()
return result
def save_recommend_question_answer(session: SessionDep, record_id: int,
answer: dict = None) -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
recommended_question_answer = orjson.dumps(answer).decode()
json_str = '[]'
if answer and answer.get('content') and answer.get('content') != '':
try:
json_str = extract_nested_json(answer.get('content'))
if not json_str:
json_str = '[]'
except Exception as e:
pass
recommended_question = json_str
stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values(
recommended_question_answer=recommended_question_answer,
recommended_question=recommended_question,
)
session.execute(stmt)
session.commit()
record = get_chat_record_by_id(session, record_id)
record.recommended_question_answer = recommended_question_answer
record.recommended_question = recommended_question
return record
def save_sql(session: SessionDep, record_id: int, sql: str) -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
record = get_chat_record_by_id(session, record_id)
record.sql = sql
result = ChatRecord(**record.model_dump())
stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
sql=record.sql
)
session.execute(stmt)
session.commit()
return result
def save_chart_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values(
chart_answer=answer,
)
session.execute(stmt)
session.commit()
record = get_chat_record_by_id(session, record_id)
return record
def save_chart(session: SessionDep, record_id: int, chart: str) -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
record = get_chat_record_by_id(session, record_id)
record.chart = chart
result = ChatRecord(**record.model_dump())
stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
chart=record.chart
)
session.execute(stmt)
session.commit()
return result
def save_predict_data(session: SessionDep, record_id: int, data: str = '') -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
record = get_chat_record_by_id(session, record_id)
record.predict_data = data
result = ChatRecord(**record.model_dump())
stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
predict_data=record.predict_data
)
session.execute(stmt)
session.commit()
return result
def save_error_message(session: SessionDep, record_id: int, message: str) -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
record = get_chat_record_by_id(session, record_id)
record.error = message
record.finish = True
record.finish_time = datetime.datetime.now()
result = ChatRecord(**record.model_dump())
stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
error=record.error,
finish=record.finish,
finish_time=record.finish_time
)
session.execute(stmt)
session.commit()
return result
def save_sql_exec_data(session: SessionDep, record_id: int, data: str) -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
record = get_chat_record_by_id(session, record_id)
record.data = data
result = ChatRecord(**record.model_dump())
stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
data=record.data,
)
session.execute(stmt)
session.commit()
return result
def finish_record(session: SessionDep, record_id: int) -> ChatRecord:
if not record_id:
raise Exception("Record id cannot be None")
record = get_chat_record_by_id(session, record_id)
record.finish = True
record.finish_time = datetime.datetime.now()
result = ChatRecord(**record.model_dump())
stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
finish=record.finish,
finish_time=record.finish_time
)
session.execute(stmt)
session.commit()
return result
def get_old_questions(session: SessionDep, datasource: int):
records = []
if not datasource:
return records
stmt = select(ChatRecord.question).where(
and_(ChatRecord.datasource == datasource, ChatRecord.question.isnot(None),
ChatRecord.error.is_(None))).order_by(
ChatRecord.create_time.desc()).limit(20)
result = session.execute(stmt)
for r in result:
records.append(r.question)
return records