diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py new file mode 100644 index 0000000..0108155 --- /dev/null +++ b/backend/apps/chat/curd/chat.py @@ -0,0 +1,725 @@ +import datetime +from typing import List + +import orjson +import sqlparse +from sqlalchemy import and_, select, update +from sqlalchemy.orm import aliased + +from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat, ChatQuestion, ChatLog, \ + TypeEnum, OperationEnum, ChatRecordResult +from apps.datasource.models.datasource import CoreDatasource +from apps.system.crud.assistant import AssistantOutDsFactory +from common.core.deps import CurrentAssistant, SessionDep, CurrentUser +from common.utils.utils import extract_nested_json + + +def get_chat_record_by_id(session: SessionDep, record_id: int): + record: ChatRecord | None = None + + stmt = select(ChatRecord.id, ChatRecord.question, ChatRecord.chat_id, ChatRecord.datasource, ChatRecord.engine_type, + ChatRecord.ai_modal_id, ChatRecord.create_by).where( + and_(ChatRecord.id == record_id)) + result = session.execute(stmt) + for r in result: + record = ChatRecord(id=r.id, question=r.question, chat_id=r.chat_id, datasource=r.datasource, + engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by) + return record + + +def list_chats(session: SessionDep, current_user: CurrentUser) -> List[Chat]: + oid = current_user.oid if current_user.oid is not None else 1 + chart_list = session.query(Chat).filter(and_(Chat.create_by == current_user.id, Chat.oid == oid)).order_by( + Chat.create_time.desc()).all() + return chart_list + + +def rename_chat(session: SessionDep, rename_object: RenameChat) -> str: + chat = session.get(Chat, rename_object.id) + if not chat: + raise Exception(f"Chat with id {rename_object.id} not found") + + chat.brief = rename_object.brief.strip()[:20] + session.add(chat) + session.flush() + session.refresh(chat) + + brief = chat.brief + session.commit() + return brief + + +def delete_chat(session, chart_id) -> str: + chat = session.query(Chat).filter(Chat.id == chart_id).first() + if not chat: + return f'Chat with id {chart_id} has been deleted' + + session.delete(chat) + session.commit() + + return f'Chat with id {chart_id} has been deleted' + + +def get_chart_config(session: SessionDep, chart_record_id: int): + stmt = select(ChatRecord.chart).where(and_(ChatRecord.id == chart_record_id)) + res = session.execute(stmt) + for row in res: + try: + return orjson.loads(row.chart) + except Exception: + pass + return {} + + +def get_last_execute_sql_error(session: SessionDep, chart_id: int): + stmt = select(ChatRecord.error).where(and_(ChatRecord.chat_id == chart_id)).order_by( + ChatRecord.create_time.desc()).limit(1) + res = session.execute(stmt).scalar() + if res: + try: + obj = orjson.loads(res) + if obj.get('type') and obj.get('type') == 'exec-sql-err': + return obj.get('traceback') + except Exception: + pass + + return None + + +def get_chat_chart_data(session: SessionDep, chart_record_id: int): + stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chart_record_id)) + res = session.execute(stmt) + for row in res: + try: + return orjson.loads(row.data) + except Exception: + pass + return [] + + +def get_chat_predict_data(session: SessionDep, chart_record_id: int): + stmt = select(ChatRecord.predict_data).where(and_(ChatRecord.id == chart_record_id)) + res = session.execute(stmt) + for row in res: + try: + return orjson.loads(row.predict_data) + except Exception: + pass + return [] + + +def get_chat_with_records_with_data(session: SessionDep, chart_id: int, current_user: CurrentUser, + current_assistant: CurrentAssistant) -> ChatInfo: + return get_chat_with_records(session, chart_id, current_user, current_assistant, True) + + +dynamic_ds_types = [1, 3] + + +def get_chat_with_records(session: SessionDep, chart_id: int, current_user: CurrentUser, + current_assistant: CurrentAssistant, with_data: bool = False) -> ChatInfo: + chat = session.get(Chat, chart_id) + if not chat: + raise Exception(f"Chat with id {chart_id} not found") + + chat_info = ChatInfo(**chat.model_dump()) + + if current_assistant and current_assistant.type in dynamic_ds_types: + out_ds_instance = AssistantOutDsFactory.get_instance(current_assistant) + ds = out_ds_instance.get_ds(chat.datasource) + else: + ds = session.get(CoreDatasource, chat.datasource) if chat.datasource else None + + if not ds: + chat_info.datasource_exists = False + chat_info.datasource_name = 'Datasource not exist' + else: + chat_info.datasource_exists = True + chat_info.datasource_name = ds.name + chat_info.ds_type = ds.type + + sql_alias_log = aliased(ChatLog) + chart_alias_log = aliased(ChatLog) + analysis_alias_log = aliased(ChatLog) + predict_alias_log = aliased(ChatLog) + + stmt = (select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.create_time, ChatRecord.finish_time, + ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql, + ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict, + ChatRecord.datasource_select_answer, ChatRecord.analysis_record_id, ChatRecord.predict_record_id, + ChatRecord.recommended_question, ChatRecord.first_chat, + ChatRecord.finish, ChatRecord.error, + sql_alias_log.reasoning_content.label('sql_reasoning_content'), + chart_alias_log.reasoning_content.label('chart_reasoning_content'), + analysis_alias_log.reasoning_content.label('analysis_reasoning_content'), + predict_alias_log.reasoning_content.label('predict_reasoning_content') + ) + .outerjoin(sql_alias_log, and_(sql_alias_log.pid == ChatRecord.id, + sql_alias_log.type == TypeEnum.CHAT, + sql_alias_log.operate == OperationEnum.GENERATE_SQL)) + .outerjoin(chart_alias_log, and_(chart_alias_log.pid == ChatRecord.id, + chart_alias_log.type == TypeEnum.CHAT, + chart_alias_log.operate == OperationEnum.GENERATE_CHART)) + .outerjoin(analysis_alias_log, and_(analysis_alias_log.pid == ChatRecord.id, + analysis_alias_log.type == TypeEnum.CHAT, + analysis_alias_log.operate == OperationEnum.ANALYSIS)) + .outerjoin(predict_alias_log, and_(predict_alias_log.pid == ChatRecord.id, + predict_alias_log.type == TypeEnum.CHAT, + predict_alias_log.operate == OperationEnum.PREDICT_DATA)) + .where(and_(ChatRecord.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by( + ChatRecord.create_time)) + if with_data: + stmt = select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.create_time, ChatRecord.finish_time, + ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql, + ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict, + ChatRecord.datasource_select_answer, ChatRecord.analysis_record_id, ChatRecord.predict_record_id, + ChatRecord.recommended_question, ChatRecord.first_chat, + ChatRecord.finish, ChatRecord.error, ChatRecord.data, ChatRecord.predict_data).where( + and_(ChatRecord.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by( + ChatRecord.create_time) + + result = session.execute(stmt).all() + record_list: list[ChatRecordResult] = [] + for row in result: + if not with_data: + record_list.append( + ChatRecordResult(id=row.id, chat_id=row.chat_id, create_time=row.create_time, + finish_time=row.finish_time, + question=row.question, sql_answer=row.sql_answer, sql=row.sql, + chart_answer=row.chart_answer, chart=row.chart, + analysis=row.analysis, predict=row.predict, + datasource_select_answer=row.datasource_select_answer, + analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id, + recommended_question=row.recommended_question, first_chat=row.first_chat, + finish=row.finish, error=row.error, + sql_reasoning_content=row.sql_reasoning_content, + chart_reasoning_content=row.chart_reasoning_content, + analysis_reasoning_content=row.analysis_reasoning_content, + predict_reasoning_content=row.predict_reasoning_content, + )) + else: + record_list.append( + ChatRecordResult(id=row.id, chat_id=row.chat_id, create_time=row.create_time, + finish_time=row.finish_time, + question=row.question, sql_answer=row.sql_answer, sql=row.sql, + chart_answer=row.chart_answer, chart=row.chart, + analysis=row.analysis, predict=row.predict, + datasource_select_answer=row.datasource_select_answer, + analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id, + recommended_question=row.recommended_question, first_chat=row.first_chat, + finish=row.finish, error=row.error, data=row.data, predict_data=row.predict_data)) + + result = list(map(format_record, record_list)) + + chat_info.records = result + + return chat_info + + +def format_record(record: ChatRecordResult): + _dict = record.model_dump() + + if record.sql_answer and record.sql_answer.strip() != '' and record.sql_answer.strip()[0] == '{' and \ + record.sql_answer.strip()[-1] == '}': + _obj = orjson.loads(record.sql_answer) + _dict['sql_answer'] = _obj.get('reasoning_content') + if record.sql_reasoning_content and record.sql_reasoning_content.strip() != '': + _dict['sql_answer'] = record.sql_reasoning_content + if record.chart_answer and record.chart_answer.strip() != '' and record.chart_answer.strip()[0] == '{' and \ + record.chart_answer.strip()[-1] == '}': + _obj = orjson.loads(record.chart_answer) + _dict['chart_answer'] = _obj.get('reasoning_content') + if record.chart_reasoning_content and record.chart_reasoning_content.strip() != '': + _dict['chart_answer'] = record.chart_reasoning_content + if record.analysis and record.analysis.strip() != '' and record.analysis.strip()[0] == '{' and \ + record.analysis.strip()[-1] == '}': + _obj = orjson.loads(record.analysis) + _dict['analysis_thinking'] = _obj.get('reasoning_content') + _dict['analysis'] = _obj.get('content') + if record.analysis_reasoning_content and record.analysis_reasoning_content.strip() != '': + _dict['analysis_thinking'] = record.analysis_reasoning_content + if record.predict and record.predict.strip() != '' and record.predict.strip()[0] == '{' and record.predict.strip()[ + -1] == '}': + _obj = orjson.loads(record.predict) + _dict['predict'] = _obj.get('reasoning_content') + _dict['predict_content'] = _obj.get('content') + if record.predict_reasoning_content and record.predict_reasoning_content.strip() != '': + _dict['predict'] = record.predict_reasoning_content + if record.data and record.data.strip() != '': + try: + _obj = orjson.loads(record.data) + _dict['data'] = _obj + except Exception: + pass + if record.predict_data and record.predict_data.strip() != '': + try: + _obj = orjson.loads(record.predict_data) + _dict['predict_data'] = _obj + except Exception: + pass + if record.sql and record.sql.strip() != '': + try: + _dict['sql'] = sqlparse.format(record.sql, reindent=True) + except Exception: + pass + + return _dict + + +def list_generate_sql_logs(session: SessionDep, chart_id: int) -> List[ChatLog]: + stmt = select(ChatLog).where( + and_(ChatLog.pid.in_(select(ChatRecord.id).where(and_(ChatRecord.chat_id == chart_id))), + ChatLog.type == TypeEnum.CHAT, ChatLog.operate == OperationEnum.GENERATE_SQL)).order_by( + ChatLog.start_time) + result = session.execute(stmt).all() + _list = [] + for row in result: + for r in row: + _list.append(ChatLog(**r.model_dump())) + return _list + + +def list_generate_chart_logs(session: SessionDep, chart_id: int) -> List[ChatLog]: + stmt = select(ChatLog).where( + and_(ChatLog.pid.in_(select(ChatRecord.id).where(and_(ChatRecord.chat_id == chart_id))), + ChatLog.type == TypeEnum.CHAT, ChatLog.operate == OperationEnum.GENERATE_CHART)).order_by( + ChatLog.start_time) + result = session.execute(stmt).all() + _list = [] + for row in result: + for r in row: + _list.append(ChatLog(**r.model_dump())) + return _list + + +def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat, + require_datasource: bool = True) -> ChatInfo: + if not create_chat_obj.datasource and require_datasource: + raise Exception("Datasource cannot be None") + + if not create_chat_obj.question or create_chat_obj.question.strip() == '': + create_chat_obj.question = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + chat = Chat(create_time=datetime.datetime.now(), + create_by=current_user.id, + oid=current_user.oid if current_user.oid is not None else 1, + brief=create_chat_obj.question.strip()[:20], + origin=create_chat_obj.origin if create_chat_obj.origin is not None else 0) + ds: CoreDatasource | None = None + if create_chat_obj.datasource: + chat.datasource = create_chat_obj.datasource + ds = session.get(CoreDatasource, create_chat_obj.datasource) + + if not ds: + raise Exception(f"Datasource with id {create_chat_obj.datasource} not found") + + chat.engine_type = ds.type_name + else: + chat.engine_type = '' + + chat_info = ChatInfo(**chat.model_dump()) + + session.add(chat) + session.flush() + session.refresh(chat) + chat_info.id = chat.id + session.commit() + + if ds: + chat_info.datasource_exists = True + chat_info.datasource_name = ds.name + chat_info.ds_type = ds.type + + if require_datasource and ds: + # generate first empty record + record = ChatRecord() + record.chat_id = chat.id + record.datasource = ds.id + record.engine_type = ds.type_name + record.first_chat = True + record.finish = True + record.create_time = datetime.datetime.now() + record.create_by = current_user.id + + _record = ChatRecord(**record.model_dump()) + + session.add(record) + session.flush() + session.refresh(record) + _record.id = record.id + session.commit() + + chat_info.records.append(_record) + + return chat_info + + +def save_question(session: SessionDep, current_user: CurrentUser, question: ChatQuestion) -> ChatRecord: + if not question.chat_id: + raise Exception("ChatId cannot be None") + if not question.question or question.question.strip() == '': + raise Exception("Question cannot be Empty") + + # chat = session.query(Chat).filter(Chat.id == question.chat_id).first() + chat: Chat = session.get(Chat, question.chat_id) + if not chat: + raise Exception(f"Chat with id {question.chat_id} not found") + + record = ChatRecord() + record.question = question.question + record.chat_id = chat.id + record.create_time = datetime.datetime.now() + record.create_by = current_user.id + record.datasource = chat.datasource + record.engine_type = chat.engine_type + record.ai_modal_id = question.ai_modal_id + + result = ChatRecord(**record.model_dump()) + + session.add(record) + session.flush() + session.refresh(record) + result.id = record.id + session.commit() + + return result + + +def save_analysis_predict_record(session: SessionDep, base_record: ChatRecord, action_type: str) -> ChatRecord: + record = ChatRecord() + record.question = base_record.question + record.chat_id = base_record.chat_id + record.datasource = base_record.datasource + record.engine_type = base_record.engine_type + record.ai_modal_id = base_record.ai_modal_id + record.create_time = datetime.datetime.now() + record.create_by = base_record.create_by + record.chart = base_record.chart + record.data = base_record.data + + if action_type == 'analysis': + record.analysis_record_id = base_record.id + elif action_type == 'predict': + record.predict_record_id = base_record.id + + result = ChatRecord(**record.model_dump()) + + session.add(record) + session.flush() + session.refresh(record) + result.id = record.id + session.commit() + + return result + + +def start_log(session: SessionDep, ai_modal_id: int, ai_modal_name: str, operate: OperationEnum, record_id: int, + full_message: list[dict]) -> ChatLog: + log = ChatLog(type=TypeEnum.CHAT, operate=operate, pid=record_id, ai_modal_id=ai_modal_id, base_modal=ai_modal_name, + messages=full_message, start_time=datetime.datetime.now()) + + result = ChatLog(**log.model_dump()) + + session.add(log) + session.flush() + session.refresh(log) + result.id = log.id + session.commit() + + return result + + +def end_log(session: SessionDep, log: ChatLog, full_message: list[dict], reasoning_content: str = None, + token_usage=None) -> ChatLog: + if token_usage is None: + token_usage = {} + log.messages = full_message + log.token_usage = token_usage + log.finish_time = datetime.datetime.now() + log.reasoning_content = reasoning_content if reasoning_content and len(reasoning_content.strip()) > 0 else None + + stmt = update(ChatLog).where(and_(ChatLog.id == log.id)).values( + messages=log.messages, + token_usage=log.token_usage, + finish_time=log.finish_time, + reasoning_content=log.reasoning_content + ) + session.execute(stmt) + session.commit() + + return log + + +def save_sql_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values( + sql_answer=answer, + ) + + session.execute(stmt) + + session.commit() + + record = get_chat_record_by_id(session, record_id) + + return record + + +def save_analysis_answer(session: SessionDep, record_id: int, answer: str = '') -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values( + analysis=answer, + ) + + session.execute(stmt) + + session.commit() + + record = get_chat_record_by_id(session, record_id) + + return record + + +def save_predict_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values( + predict=answer, + ) + + session.execute(stmt) + + session.commit() + + record = get_chat_record_by_id(session, record_id) + + return record + + +def save_select_datasource_answer(session: SessionDep, record_id: int, answer: str, + datasource: int = None, engine_type: str = None) -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + record = get_chat_record_by_id(session, record_id) + + record.datasource_select_answer = answer + + if datasource: + record.datasource = datasource + record.engine_type = engine_type + + result = ChatRecord(**record.model_dump()) + + if datasource: + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + datasource_select_answer=record.datasource_select_answer, + datasource=record.datasource, + engine_type=record.engine_type, + ) + else: + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + datasource_select_answer=record.datasource_select_answer, + ) + + session.execute(stmt) + + session.commit() + + return result + + +def save_recommend_question_answer(session: SessionDep, record_id: int, + answer: dict = None) -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + + recommended_question_answer = orjson.dumps(answer).decode() + + json_str = '[]' + if answer and answer.get('content') and answer.get('content') != '': + try: + json_str = extract_nested_json(answer.get('content')) + + if not json_str: + json_str = '[]' + except Exception as e: + pass + recommended_question = json_str + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values( + recommended_question_answer=recommended_question_answer, + recommended_question=recommended_question, + ) + + session.execute(stmt) + + session.commit() + + record = get_chat_record_by_id(session, record_id) + record.recommended_question_answer = recommended_question_answer + record.recommended_question = recommended_question + + return record + + +def save_sql(session: SessionDep, record_id: int, sql: str) -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + + record = get_chat_record_by_id(session, record_id) + + record.sql = sql + + result = ChatRecord(**record.model_dump()) + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + sql=record.sql + ) + + session.execute(stmt) + + session.commit() + + return result + + +def save_chart_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record_id)).values( + chart_answer=answer, + ) + + session.execute(stmt) + + session.commit() + + record = get_chat_record_by_id(session, record_id) + + return record + + +def save_chart(session: SessionDep, record_id: int, chart: str) -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + record = get_chat_record_by_id(session, record_id) + + record.chart = chart + + result = ChatRecord(**record.model_dump()) + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + chart=record.chart + ) + + session.execute(stmt) + + session.commit() + + return result + + +def save_predict_data(session: SessionDep, record_id: int, data: str = '') -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + record = get_chat_record_by_id(session, record_id) + + record.predict_data = data + + result = ChatRecord(**record.model_dump()) + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + predict_data=record.predict_data + ) + + session.execute(stmt) + + session.commit() + + return result + + +def save_error_message(session: SessionDep, record_id: int, message: str) -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + record = get_chat_record_by_id(session, record_id) + + record.error = message + record.finish = True + record.finish_time = datetime.datetime.now() + + result = ChatRecord(**record.model_dump()) + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + error=record.error, + finish=record.finish, + finish_time=record.finish_time + ) + + session.execute(stmt) + + session.commit() + + return result + + +def save_sql_exec_data(session: SessionDep, record_id: int, data: str) -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + record = get_chat_record_by_id(session, record_id) + + record.data = data + + result = ChatRecord(**record.model_dump()) + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + data=record.data, + ) + + session.execute(stmt) + + session.commit() + + return result + + +def finish_record(session: SessionDep, record_id: int) -> ChatRecord: + if not record_id: + raise Exception("Record id cannot be None") + record = get_chat_record_by_id(session, record_id) + + record.finish = True + record.finish_time = datetime.datetime.now() + + result = ChatRecord(**record.model_dump()) + + stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values( + finish=record.finish, + finish_time=record.finish_time + ) + + session.execute(stmt) + + session.commit() + + return result + + +def get_old_questions(session: SessionDep, datasource: int): + records = [] + if not datasource: + return records + stmt = select(ChatRecord.question).where( + and_(ChatRecord.datasource == datasource, ChatRecord.question.isnot(None), + ChatRecord.error.is_(None))).order_by( + ChatRecord.create_time.desc()).limit(20) + result = session.execute(stmt) + for r in result: + records.append(r.question) + return records