Files
SQLBot/backend/apps/chat/task/llm.py

1287 lines
64 KiB
Python
Raw Permalink Normal View History

2025-09-08 16:36:01 +08:00
import concurrent
import json
import os
import traceback
import urllib.parse
import warnings
from concurrent.futures import ThreadPoolExecutor, Future
from datetime import datetime
from typing import Any, List, Optional, Union, Dict
import numpy as np
import orjson
import pandas as pd
import requests
import sqlparse
from langchain.chat_models.base import BaseChatModel
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlmodel import create_engine, Session
from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_default_config
from apps.chat.curd.chat import save_question, save_sql_answer, save_sql, \
save_error_message, save_sql_exec_data, save_chart_answer, save_chart, \
finish_record, save_analysis_answer, save_predict_answer, save_predict_data, \
save_select_datasource_answer, save_recommend_question_answer, \
get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \
get_last_execute_sql_error
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum
from apps.datasource.crud.datasource import get_table_schema
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
from apps.datasource.models.datasource import CoreDatasource
from apps.db.db import exec_sql, get_version, check_connection
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
from apps.system.schemas.system_schema import AssistantOutDsSchema
from apps.terminology.curd.terminology import get_terminology_template
from common.core.config import settings
from common.core.deps import CurrentAssistant, CurrentUser
from common.error import SingleMessageError, SQLBotDBError, ParseSQLResultError, SQLBotDBConnectionError
from common.utils.utils import SQLBotLogUtil, extract_nested_json, prepare_for_orjson
# Silence noisy library deprecation/user warnings emitted by dependencies.
warnings.filterwarnings("ignore")
# Max number of prior history messages replayed into a fresh LLM conversation.
base_message_count_limit = 6
# Shared pool used by run_task_async to execute chat tasks in the background.
executor = ThreadPoolExecutor(max_workers=200)
# Assistant types whose datasources are resolved dynamically via AssistantOutDs.
dynamic_ds_types = [1, 3]
# Placeholder prefix used when mapping assistant sub-SQL tables to temp tables
# (see generate_assistant_dynamic_sql).
dynamic_subsql_prefix = 'select * from sqlbot_dynamic_temp_table_'
class LLMService:
ds: CoreDatasource
chat_question: ChatQuestion
record: ChatRecord
config: LLMConfig
llm: BaseChatModel
sql_message: List[Union[BaseMessage, dict[str, Any]]] = []
chart_message: List[Union[BaseMessage, dict[str, Any]]] = []
session: Session
current_user: CurrentUser
current_assistant: Optional[CurrentAssistant] = None
out_ds_instance: Optional[AssistantOutDs] = None
change_title: bool = False
generate_sql_logs: List[ChatLog] = []
generate_chart_logs: List[ChatLog] = []
current_logs: dict[OperationEnum, ChatLog] = {}
chunk_list: List[str] = []
future: Future
last_execute_sql_error: str = None
    def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
                 current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False,
                 config: LLMConfig = None):
        """Bind a chat session: resolve the chat and its datasource, build the
        LLM client, and seed the SQL/chart conversations from persisted logs.

        Args:
            current_user: requesting user (drives permissions, language, oid).
            chat_question: incoming question; enriched in place with engine,
                db schema, language and prior-error context.
            current_assistant: optional assistant context; types listed in
                dynamic_ds_types resolve datasources externally.
            no_reasoning: strip qwen-style 'enable_thinking' from extra_body.
            config: LLM configuration. NOTE(review): when None while
                no_reasoning is True, the additional_params access below would
                fail — confirm callers always pass a config (see create()).
        Raises:
            SingleMessageError: when the chat or datasource cannot be resolved.
        """
        self.chunk_list = []
        engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
        session_maker = sessionmaker(bind=engine)
        self.session = session_maker()
        # SQLModel sessions expose .exec; plain SQLAlchemy sessions only .execute — normalize.
        self.session.exec = self.session.exec if hasattr(self.session, "exec") else self.session.execute
        self.current_user = current_user
        self.current_assistant = current_assistant
        chat_id = chat_question.chat_id
        chat: Chat | None = self.session.get(Chat, chat_id)
        if not chat:
            raise SingleMessageError(f"Chat with id {chat_id} not found")
        ds: CoreDatasource | AssistantOutDsSchema | None = None
        if chat.datasource:
            # Get available datasource
            if current_assistant and current_assistant.type in dynamic_ds_types:
                # Assistant-managed (external) datasource.
                self.out_ds_instance = AssistantOutDsFactory.get_instance(current_assistant)
                ds = self.out_ds_instance.get_ds(chat.datasource)
                if not ds:
                    raise SingleMessageError("No available datasource configuration found")
                chat_question.engine = ds.type + get_version(ds)
                chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id)
            else:
                ds = self.session.get(CoreDatasource, chat.datasource)
                if not ds:
                    raise SingleMessageError("No available datasource configuration found")
                # excel-type datasources report the PostgreSQL engine string.
                chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds)
                chat_question.db_schema = get_table_schema(session=self.session, current_user=current_user, ds=ds)
        self.generate_sql_logs = list_generate_sql_logs(session=self.session, chart_id=chat_id)
        self.generate_chart_logs = list_generate_chart_logs(session=self.session, chart_id=chat_id)
        # No prior SQL logs → first question, so the chat title gets renamed from it.
        self.change_title = len(self.generate_sql_logs) == 0
        chat_question.lang = get_lang_name(current_user.language)
        self.ds = (ds if isinstance(ds, AssistantOutDsSchema) else CoreDatasource(**ds.model_dump())) if ds else None
        self.chat_question = chat_question
        self.config = config
        if no_reasoning:
            # only work while using qwen
            if self.config.additional_params:
                if self.config.additional_params.get('extra_body'):
                    if self.config.additional_params.get('extra_body').get('enable_thinking'):
                        del self.config.additional_params['extra_body']['enable_thinking']
        self.chat_question.ai_modal_id = self.config.model_id
        self.chat_question.ai_modal_name = self.config.model_name
        # Create LLM instance through factory
        llm_instance = LLMFactory.create_llm(self.config)
        self.llm = llm_instance.llm
        # Feed the previous failed-SQL error (if any) back into the prompt context.
        last_execute_sql_error = get_last_execute_sql_error(self.session, self.chat_question.chat_id)
        if last_execute_sql_error:
            self.chat_question.error_msg = f'''<error-msg>
{last_execute_sql_error}
</error-msg>'''
        else:
            self.chat_question.error_msg = ''
        self.init_messages()
@classmethod
async def create(cls, *args, **kwargs):
config: LLMConfig = await get_default_config()
instance = cls(*args, **kwargs, config=config)
return instance
def is_running(self, timeout=0.5):
try:
r = concurrent.futures.wait([self.future], timeout)
if len(r.not_done) > 0:
return True
else:
return False
except Exception as e:
return True
def init_messages(self):
last_sql_messages: List[dict[str, Any]] = self.generate_sql_logs[-1].messages if len(
self.generate_sql_logs) > 0 else []
# todo maybe can configure
count_limit = 0 - base_message_count_limit
self.sql_message = []
# add sys prompt
self.sql_message.append(SystemMessage(content=self.chat_question.sql_sys_question()))
if last_sql_messages is not None and len(last_sql_messages) > 0:
# limit count
for last_sql_message in last_sql_messages[count_limit:]:
_msg: BaseMessage
if last_sql_message['type'] == 'human':
_msg = HumanMessage(content=last_sql_message['content'])
self.sql_message.append(_msg)
elif last_sql_message['type'] == 'ai':
_msg = AIMessage(content=last_sql_message['content'])
self.sql_message.append(_msg)
last_chart_messages: List[dict[str, Any]] = self.generate_chart_logs[-1].messages if len(
self.generate_chart_logs) > 0 else []
self.chart_message = []
# add sys prompt
self.chart_message.append(SystemMessage(content=self.chat_question.chart_sys_question()))
if last_chart_messages is not None and len(last_chart_messages) > 0:
# limit count
for last_chart_message in last_chart_messages:
_msg: BaseMessage
if last_chart_message.get('type') == 'human':
_msg = HumanMessage(content=last_chart_message.get('content'))
self.chart_message.append(_msg)
elif last_chart_message.get('type') == 'ai':
_msg = AIMessage(content=last_chart_message.get('content'))
self.chart_message.append(_msg)
def init_record(self) -> ChatRecord:
self.record = save_question(session=self.session, current_user=self.current_user, question=self.chat_question)
return self.record
def get_record(self):
return self.record
def set_record(self, record: ChatRecord):
self.record = record
def get_fields_from_chart(self):
chart_info = get_chart_config(self.session, self.record.id)
fields = []
if chart_info.get('columns') and len(chart_info.get('columns')) > 0:
for column in chart_info.get('columns'):
column_str = column.get('value')
if column.get('value') != column.get('name'):
column_str = column_str + '(' + column.get('name') + ')'
fields.append(column_str)
if chart_info.get('axis'):
for _type in ['x', 'y', 'series']:
if chart_info.get('axis').get(_type):
column = chart_info.get('axis').get(_type)
column_str = column.get('value')
if column.get('value') != column.get('name'):
column_str = column_str + '(' + column.get('name') + ')'
fields.append(column_str)
return fields
    def generate_analysis(self):
        """Stream an LLM analysis of the current chart's fields and data.

        Yields {'content', 'reasoning_content'} chunks as they arrive, logs the
        full exchange, and persists the final analysis answer on the record.
        """
        fields = self.get_fields_from_chart()
        self.chat_question.fields = orjson.dumps(fields).decode()
        data = get_chat_chart_data(self.session, self.record.id)
        self.chat_question.data = orjson.dumps(data.get('data')).decode()
        analysis_msg: List[Union[BaseMessage, dict[str, Any]]] = []
        self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
                                                                    self.current_user.oid)
        analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question()))
        analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question()))
        self.current_logs[OperationEnum.ANALYSIS] = start_log(session=self.session,
                                                              ai_modal_id=self.chat_question.ai_modal_id,
                                                              ai_modal_name=self.chat_question.ai_modal_name,
                                                              operate=OperationEnum.ANALYSIS,
                                                              record_id=self.record.id,
                                                              full_message=[
                                                                  {'type': msg.type,
                                                                   'content': msg.content} for
                                                                  msg
                                                                  in analysis_msg])
        full_thinking_text = ''
        full_analysis_text = ''
        res = self.llm.stream(analysis_msg)
        token_usage = {}
        for chunk in res:
            SQLBotLogUtil.info(chunk)
            reasoning_content_chunk = ''
            # Some providers surface chain-of-thought via additional_kwargs.
            if 'reasoning_content' in chunk.additional_kwargs:
                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
            if reasoning_content_chunk is None:
                reasoning_content_chunk = ''
            full_thinking_text += reasoning_content_chunk
            full_analysis_text += chunk.content
            yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
            get_token_usage(chunk, token_usage)
        analysis_msg.append(AIMessage(full_analysis_text))
        self.current_logs[OperationEnum.ANALYSIS] = end_log(session=self.session,
                                                            log=self.current_logs[
                                                                OperationEnum.ANALYSIS],
                                                            full_message=[
                                                                {'type': msg.type,
                                                                 'content': msg.content}
                                                                for msg in analysis_msg],
                                                            reasoning_content=full_thinking_text,
                                                            token_usage=token_usage)
        self.record = save_analysis_answer(session=self.session, record_id=self.record.id,
                                           answer=orjson.dumps({'content': full_analysis_text}).decode())
    def generate_predict(self):
        """Stream an LLM prediction based on the current chart's fields and data.

        Yields {'content', 'reasoning_content'} chunks, persists the answer on
        the record, and logs the full exchange.
        """
        fields = self.get_fields_from_chart()
        self.chat_question.fields = orjson.dumps(fields).decode()
        data = get_chat_chart_data(self.session, self.record.id)
        self.chat_question.data = orjson.dumps(data.get('data')).decode()
        predict_msg: List[Union[BaseMessage, dict[str, Any]]] = []
        predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question()))
        predict_msg.append(HumanMessage(content=self.chat_question.predict_user_question()))
        self.current_logs[OperationEnum.PREDICT_DATA] = start_log(session=self.session,
                                                                  ai_modal_id=self.chat_question.ai_modal_id,
                                                                  ai_modal_name=self.chat_question.ai_modal_name,
                                                                  operate=OperationEnum.PREDICT_DATA,
                                                                  record_id=self.record.id,
                                                                  full_message=[
                                                                      {'type': msg.type,
                                                                       'content': msg.content} for
                                                                      msg
                                                                      in predict_msg])
        full_thinking_text = ''
        full_predict_text = ''
        res = self.llm.stream(predict_msg)
        token_usage = {}
        for chunk in res:
            SQLBotLogUtil.info(chunk)
            reasoning_content_chunk = ''
            # Some providers surface chain-of-thought via additional_kwargs.
            if 'reasoning_content' in chunk.additional_kwargs:
                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
            if reasoning_content_chunk is None:
                reasoning_content_chunk = ''
            full_thinking_text += reasoning_content_chunk
            full_predict_text += chunk.content
            yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
            get_token_usage(chunk, token_usage)
        predict_msg.append(AIMessage(full_predict_text))
        # NOTE: answer is saved before end_log here (generate_analysis does the reverse).
        self.record = save_predict_answer(session=self.session, record_id=self.record.id,
                                          answer=orjson.dumps({'content': full_predict_text}).decode())
        self.current_logs[OperationEnum.PREDICT_DATA] = end_log(session=self.session,
                                                                log=self.current_logs[
                                                                    OperationEnum.PREDICT_DATA],
                                                                full_message=[
                                                                    {'type': msg.type,
                                                                     'content': msg.content}
                                                                    for msg in predict_msg],
                                                                reasoning_content=full_thinking_text,
                                                                token_usage=token_usage)
    def generate_recommend_questions_task(self):
        """Stream LLM-recommended follow-up questions for the current datasource.

        Loads the schema when missing, includes previously asked questions in
        the prompt, yields {'content', 'reasoning_content'} chunks, and finally
        yields {'recommended_question': ...} after the answer is persisted.
        """
        # get schema (only when not already resolved for this question)
        if self.ds and not self.chat_question.db_schema:
            self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
                self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session,
                                                                          current_user=self.current_user, ds=self.ds)
        guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
        guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question()))
        # Provide earlier questions from this datasource to the prompt.
        old_questions = list(map(lambda q: q.strip(), get_old_questions(self.session, self.record.datasource)))
        guess_msg.append(
            HumanMessage(content=self.chat_question.guess_user_question(orjson.dumps(old_questions).decode())))
        self.current_logs[OperationEnum.GENERATE_RECOMMENDED_QUESTIONS] = start_log(session=self.session,
                                                                                    ai_modal_id=self.chat_question.ai_modal_id,
                                                                                    ai_modal_name=self.chat_question.ai_modal_name,
                                                                                    operate=OperationEnum.GENERATE_RECOMMENDED_QUESTIONS,
                                                                                    record_id=self.record.id,
                                                                                    full_message=[
                                                                                        {'type': msg.type,
                                                                                         'content': msg.content} for
                                                                                        msg
                                                                                        in guess_msg])
        full_thinking_text = ''
        full_guess_text = ''
        token_usage = {}
        res = self.llm.stream(guess_msg)
        for chunk in res:
            SQLBotLogUtil.info(chunk)
            reasoning_content_chunk = ''
            # Some providers surface chain-of-thought via additional_kwargs.
            if 'reasoning_content' in chunk.additional_kwargs:
                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
            if reasoning_content_chunk is None:
                reasoning_content_chunk = ''
            full_thinking_text += reasoning_content_chunk
            full_guess_text += chunk.content
            yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
            get_token_usage(chunk, token_usage)
        guess_msg.append(AIMessage(full_guess_text))
        self.current_logs[OperationEnum.GENERATE_RECOMMENDED_QUESTIONS] = end_log(session=self.session,
                                                                                  log=self.current_logs[
                                                                                      OperationEnum.GENERATE_RECOMMENDED_QUESTIONS],
                                                                                  full_message=[
                                                                                      {'type': msg.type,
                                                                                       'content': msg.content}
                                                                                      for msg in guess_msg],
                                                                                  reasoning_content=full_thinking_text,
                                                                                  token_usage=token_usage)
        self.record = save_recommend_question_answer(session=self.session, record_id=self.record.id,
                                                     answer={'content': full_guess_text})
        yield {'recommended_question': self.record.recommended_question}
def select_datasource(self):
datasource_msg: List[Union[BaseMessage, dict[str, Any]]] = []
datasource_msg.append(SystemMessage(self.chat_question.datasource_sys_question()))
if self.current_assistant and self.current_assistant.type != 4:
_ds_list = get_assistant_ds(session=self.session, llm_service=self)
else:
oid: str = self.current_user.oid
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(
CoreDatasource.oid == oid)
_ds_list = [
{
"id": ds.id,
"name": ds.name,
"description": ds.description
}
for ds in self.session.exec(stmt)
]
""" _ds_list = self.session.exec(select(CoreDatasource).options(
load_only(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description))).all() """
if not _ds_list:
raise SingleMessageError('No available datasource configuration found')
ignore_auto_select = _ds_list and len(_ds_list) == 1
# ignore auto select ds
full_thinking_text = ''
full_text = ''
if not ignore_auto_select:
_ds_list_dict = []
for _ds in _ds_list:
_ds_list_dict.append(_ds)
datasource_msg.append(
HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session,
ai_modal_id=self.chat_question.ai_modal_id,
ai_modal_name=self.chat_question.ai_modal_name,
operate=OperationEnum.CHOOSE_DATASOURCE,
record_id=self.record.id,
full_message=[{'type': msg.type,
'content': msg.content} for
msg in datasource_msg])
token_usage = {}
res = self.llm.stream(datasource_msg)
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk
full_text += chunk.content
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
get_token_usage(chunk, token_usage)
datasource_msg.append(AIMessage(full_text))
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session,
log=self.current_logs[
OperationEnum.CHOOSE_DATASOURCE],
full_message=[
{'type': msg.type, 'content': msg.content}
for msg in datasource_msg],
reasoning_content=full_thinking_text,
token_usage=token_usage)
json_str = extract_nested_json(full_text)
_error: Exception | None = None
_datasource: int | None = None
_engine_type: str | None = None
try:
data: dict = _ds_list[0] if ignore_auto_select else orjson.loads(json_str)
if data.get('id') and data.get('id') != 0:
_datasource = data['id']
_chat = self.session.get(Chat, self.record.chat_id)
_chat.datasource = _datasource
if self.current_assistant and self.current_assistant.type in dynamic_ds_types:
_ds = self.out_ds_instance.get_ds(data['id'])
self.ds = _ds
self.chat_question.engine = _ds.type + get_version(self.ds)
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(self.ds.id)
_engine_type = self.chat_question.engine
_chat.engine_type = _ds.type
else:
_ds = self.session.get(CoreDatasource, _datasource)
if not _ds:
_datasource = None
raise SingleMessageError(f"Datasource configuration with id {_datasource} not found")
self.ds = CoreDatasource(**_ds.model_dump())
self.chat_question.engine = (_ds.type_name if _ds.type != 'excel' else 'PostgreSQL') + get_version(
self.ds)
self.chat_question.db_schema = get_table_schema(session=self.session,
current_user=self.current_user, ds=self.ds)
_engine_type = self.chat_question.engine
_chat.engine_type = _ds.type_name
# save chat
self.session.add(_chat)
self.session.flush()
self.session.refresh(_chat)
self.session.commit()
elif data['fail']:
raise SingleMessageError(data['fail'])
else:
raise SingleMessageError('No available datasource configuration found')
except Exception as e:
_error = e
if not ignore_auto_select:
self.record = save_select_datasource_answer(session=self.session, record_id=self.record.id,
answer=orjson.dumps({'content': full_text}).decode(),
datasource=_datasource,
engine_type=_engine_type)
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
self.ds.oid if isinstance(self.ds,
CoreDatasource) else 1)
self.init_messages()
if _error:
raise _error
    def generate_sql(self):
        """Stream SQL generation for the current question.

        Appends the user question to the running SQL conversation, yields
        {'content', 'reasoning_content'} chunks from the LLM, logs the exchange
        and persists the raw answer on the record.
        """
        # append current question
        self.sql_message.append(HumanMessage(
            self.chat_question.sql_user_question(current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))))
        self.current_logs[OperationEnum.GENERATE_SQL] = start_log(session=self.session,
                                                                  ai_modal_id=self.chat_question.ai_modal_id,
                                                                  ai_modal_name=self.chat_question.ai_modal_name,
                                                                  operate=OperationEnum.GENERATE_SQL,
                                                                  record_id=self.record.id,
                                                                  full_message=[
                                                                      {'type': msg.type, 'content': msg.content} for msg
                                                                      in self.sql_message])
        full_thinking_text = ''
        full_sql_text = ''
        token_usage = {}
        res = self.llm.stream(self.sql_message)
        for chunk in res:
            SQLBotLogUtil.info(chunk)
            reasoning_content_chunk = ''
            # Some providers surface chain-of-thought via additional_kwargs.
            if 'reasoning_content' in chunk.additional_kwargs:
                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
            if reasoning_content_chunk is None:
                reasoning_content_chunk = ''
            full_thinking_text += reasoning_content_chunk
            full_sql_text += chunk.content
            yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
            get_token_usage(chunk, token_usage)
        self.sql_message.append(AIMessage(full_sql_text))
        self.current_logs[OperationEnum.GENERATE_SQL] = end_log(session=self.session,
                                                                log=self.current_logs[OperationEnum.GENERATE_SQL],
                                                                full_message=[{'type': msg.type, 'content': msg.content}
                                                                              for msg in self.sql_message],
                                                                reasoning_content=full_thinking_text,
                                                                token_usage=token_usage)
        self.record = save_sql_answer(session=self.session, record_id=self.record.id,
                                      answer=orjson.dumps({'content': full_sql_text}).decode())
    def generate_with_sub_sql(self, sql, sub_mappings: list):
        """Ask the LLM to rewrite ``sql`` against dynamic temp-table sub-queries.

        Args:
            sql: the previously generated SQL.
            sub_mappings: list of {'table', 'query'} placeholder mappings.
        Returns:
            The raw LLM answer text (collected, not streamed to the caller).
        """
        sub_query = json.dumps(sub_mappings, ensure_ascii=False)
        self.chat_question.sql = sql
        self.chat_question.sub_query = sub_query
        dynamic_sql_msg: List[Union[BaseMessage, dict[str, Any]]] = []
        dynamic_sql_msg.append(SystemMessage(content=self.chat_question.dynamic_sys_question()))
        dynamic_sql_msg.append(HumanMessage(content=self.chat_question.dynamic_user_question()))
        self.current_logs[OperationEnum.GENERATE_DYNAMIC_SQL] = start_log(session=self.session,
                                                                          ai_modal_id=self.chat_question.ai_modal_id,
                                                                          ai_modal_name=self.chat_question.ai_modal_name,
                                                                          operate=OperationEnum.GENERATE_DYNAMIC_SQL,
                                                                          record_id=self.record.id,
                                                                          full_message=[{'type': msg.type,
                                                                                         'content': msg.content}
                                                                                        for
                                                                                        msg in dynamic_sql_msg])
        full_thinking_text = ''
        full_dynamic_text = ''
        res = self.llm.stream(dynamic_sql_msg)
        token_usage = {}
        for chunk in res:
            SQLBotLogUtil.info(chunk)
            reasoning_content_chunk = ''
            # Some providers surface chain-of-thought via additional_kwargs.
            if 'reasoning_content' in chunk.additional_kwargs:
                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
            if reasoning_content_chunk is None:
                reasoning_content_chunk = ''
            full_thinking_text += reasoning_content_chunk
            full_dynamic_text += chunk.content
            get_token_usage(chunk, token_usage)
        dynamic_sql_msg.append(AIMessage(full_dynamic_text))
        self.current_logs[OperationEnum.GENERATE_DYNAMIC_SQL] = end_log(session=self.session,
                                                                        log=self.current_logs[
                                                                            OperationEnum.GENERATE_DYNAMIC_SQL],
                                                                        full_message=[
                                                                            {'type': msg.type,
                                                                             'content': msg.content}
                                                                            for msg in dynamic_sql_msg],
                                                                        reasoning_content=full_thinking_text,
                                                                        token_usage=token_usage)
        SQLBotLogUtil.info(full_dynamic_text)
        return full_dynamic_text
def generate_assistant_dynamic_sql(self, sql, tables: List):
ds: AssistantOutDsSchema = self.ds
sub_query = []
result_dict = {}
for table in ds.tables:
if table.name in tables and table.sql:
#sub_query.append({"table": table.name, "query": table.sql})
result_dict[table.name] = table.sql
sub_query.append({"table": table.name, "query": f'{dynamic_subsql_prefix}{table.name}'})
if not sub_query:
return None
temp_sql_text = self.generate_with_sub_sql(sql=sql, sub_mappings=sub_query)
result_dict['sqlbot_temp_sql_text'] = temp_sql_text
return result_dict
def build_table_filter(self, sql: str, filters: list):
filter = json.dumps(filters, ensure_ascii=False)
self.chat_question.sql = sql
self.chat_question.filter = filter
permission_sql_msg: List[Union[BaseMessage, dict[str, Any]]] = []
permission_sql_msg.append(SystemMessage(content=self.chat_question.filter_sys_question()))
permission_sql_msg.append(HumanMessage(content=self.chat_question.filter_user_question()))
self.current_logs[OperationEnum.GENERATE_SQL_WITH_PERMISSIONS] = start_log(session=self.session,
ai_modal_id=self.chat_question.ai_modal_id,
ai_modal_name=self.chat_question.ai_modal_name,
operate=OperationEnum.GENERATE_SQL_WITH_PERMISSIONS,
record_id=self.record.id,
full_message=[
{'type': msg.type,
'content': msg.content} for
msg
in permission_sql_msg])
full_thinking_text = ''
full_filter_text = ''
res = self.llm.stream(permission_sql_msg)
token_usage = {}
for chunk in res:
SQLBotLogUtil.info(chunk)
reasoning_content_chunk = ''
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
# else:
# reasoning_content_chunk = chunk.get('reasoning_content')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
full_thinking_text += reasoning_content_chunk
full_filter_text += chunk.content
# yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
get_token_usage(chunk, token_usage)
permission_sql_msg.append(AIMessage(full_filter_text))
self.current_logs[OperationEnum.GENERATE_SQL_WITH_PERMISSIONS] = end_log(session=self.session,
log=self.current_logs[
OperationEnum.GENERATE_SQL_WITH_PERMISSIONS],
full_message=[
{'type': msg.type,
'content': msg.content}
for msg in permission_sql_msg],
reasoning_content=full_thinking_text,
token_usage=token_usage)
SQLBotLogUtil.info(full_filter_text)
return full_filter_text
def generate_filter(self, sql: str, tables: List):
filters = get_row_permission_filters(session=self.session, current_user=self.current_user, ds=self.ds,
tables=tables)
if not filters:
return None
return self.build_table_filter(sql=sql, filters=filters)
def generate_assistant_filter(self, sql, tables: List):
ds: AssistantOutDsSchema = self.ds
filters = []
for table in ds.tables:
if table.name in tables and table.rule:
filters.append({"table": table.name, "filter": table.rule})
if not filters:
return None
return self.build_table_filter(sql=sql, filters=filters)
    def generate_chart(self, chart_type: Optional[str] = ''):
        """Stream chart-config generation for the current conversation.

        Args:
            chart_type: optional pre-selected chart type forwarded to the prompt.
        Yields {'content', 'reasoning_content'} chunks; persists the raw answer
        and logs the exchange.
        """
        # append current question
        self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type)))
        self.current_logs[OperationEnum.GENERATE_CHART] = start_log(session=self.session,
                                                                    ai_modal_id=self.chat_question.ai_modal_id,
                                                                    ai_modal_name=self.chat_question.ai_modal_name,
                                                                    operate=OperationEnum.GENERATE_CHART,
                                                                    record_id=self.record.id,
                                                                    full_message=[
                                                                        {'type': msg.type, 'content': msg.content} for
                                                                        msg
                                                                        in self.chart_message])
        full_thinking_text = ''
        full_chart_text = ''
        token_usage = {}
        res = self.llm.stream(self.chart_message)
        for chunk in res:
            SQLBotLogUtil.info(chunk)
            reasoning_content_chunk = ''
            # Some providers surface chain-of-thought via additional_kwargs.
            if 'reasoning_content' in chunk.additional_kwargs:
                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
            if reasoning_content_chunk is None:
                reasoning_content_chunk = ''
            full_thinking_text += reasoning_content_chunk
            full_chart_text += chunk.content
            yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
            get_token_usage(chunk, token_usage)
        self.chart_message.append(AIMessage(full_chart_text))
        self.record = save_chart_answer(session=self.session, record_id=self.record.id,
                                        answer=orjson.dumps({'content': full_chart_text}).decode())
        self.current_logs[OperationEnum.GENERATE_CHART] = end_log(session=self.session,
                                                                  log=self.current_logs[OperationEnum.GENERATE_CHART],
                                                                  full_message=[
                                                                      {'type': msg.type, 'content': msg.content}
                                                                      for msg in self.chart_message],
                                                                  reasoning_content=full_thinking_text,
                                                                  token_usage=token_usage)
@staticmethod
def check_sql(res: str) -> tuple[str, Optional[list]]:
json_str = extract_nested_json(res)
if json_str is None:
raise SingleMessageError(orjson.dumps({'message': 'Cannot parse sql from answer',
'traceback': "Cannot parse sql from answer:\n" + res}).decode())
sql: str
data: dict
try:
data = orjson.loads(json_str)
if data['success']:
sql = data['sql']
else:
message = data['message']
raise SingleMessageError(message)
except SingleMessageError as e:
raise e
except Exception:
raise SingleMessageError(orjson.dumps({'message': 'Cannot parse sql from answer',
'traceback': "Cannot parse sql from answer:\n" + res}).decode())
if sql.strip() == '':
raise SingleMessageError("SQL query is empty")
return sql, data.get('tables')
@staticmethod
def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
json_str = extract_nested_json(res)
if json_str is None:
return None
chart_type: Optional[str]
data: dict
try:
data = orjson.loads(json_str)
if data['success']:
chart_type = data['chart-type']
else:
return None
except Exception:
return None
return chart_type
def check_save_sql(self, res: str) -> str:
sql, *_ = self.check_sql(res=res)
save_sql(session=self.session, sql=sql, record_id=self.record.id)
self.chat_question.sql = sql
return sql
def check_save_chart(self, res: str) -> Dict[str, Any]:
json_str = extract_nested_json(res)
if json_str is None:
raise SingleMessageError(orjson.dumps({'message': 'Cannot parse chart config from answer',
'traceback': "Cannot parse chart config from answer:\n" + res}).decode())
data: dict
chart: Dict[str, Any] = {}
message = ''
error = False
try:
data = orjson.loads(json_str)
if data['type'] and data['type'] != 'error':
# todo type check
chart = data
if chart.get('columns'):
for v in chart.get('columns'):
v['value'] = v.get('value').lower()
if chart.get('axis'):
if chart.get('axis').get('x'):
chart.get('axis').get('x')['value'] = chart.get('axis').get('x').get('value').lower()
if chart.get('axis').get('y'):
chart.get('axis').get('y')['value'] = chart.get('axis').get('y').get('value').lower()
if chart.get('axis').get('series'):
chart.get('axis').get('series')['value'] = chart.get('axis').get('series').get('value').lower()
elif data['type'] == 'error':
message = data['reason']
error = True
else:
raise Exception('Chart is empty')
except Exception:
error = True
message = orjson.dumps({'message': 'Cannot parse chart config from answer',
'traceback': "Cannot parse chart config from answer:\n" + res}).decode()
if error:
raise SingleMessageError(message)
save_chart(session=self.session, chart=orjson.dumps(chart).decode(), record_id=self.record.id)
return chart
def check_save_predict_data(self, res: str) -> bool:
json_str = extract_nested_json(res)
if not json_str:
json_str = ''
save_predict_data(session=self.session, record_id=self.record.id, data=json_str)
if json_str == '':
return False
return True
def save_error(self, message: str):
return save_error_message(session=self.session, record_id=self.record.id, message=message)
def save_sql_data(self, data_obj: Dict[str, Any]):
try:
data_result = data_obj.get('data')
limit = 1000
if data_result:
data_result = prepare_for_orjson(data_result)
if data_result and len(data_result) > limit:
data_obj['data'] = data_result[:limit]
data_obj['limit'] = limit
else:
data_obj['data'] = data_result
return save_sql_exec_data(session=self.session, record_id=self.record.id,
data=orjson.dumps(data_obj).decode())
except Exception as e:
raise e
def finish(self):
return finish_record(session=self.session, record_id=self.record.id)
def execute_sql(self, sql: str):
"""Execute SQL query
Args:
ds: Data source instance
sql: SQL query statement
Returns:
Query results
"""
SQLBotLogUtil.info(f"Executing SQL on ds_id {self.ds.id}: {sql}")
try:
return exec_sql(self.ds, sql)
except Exception as e:
if isinstance(e, ParseSQLResultError):
raise e
else:
err = traceback.format_exc(limit=1, chain=True)
raise SQLBotDBError(err)
def pop_chunk(self):
try:
chunk = self.chunk_list.pop(0)
return chunk
except IndexError as e:
return None
def await_result(self):
while self.is_running():
while True:
chunk = self.pop_chunk()
if chunk is not None:
yield chunk
else:
break
while True:
chunk = self.pop_chunk()
if chunk is None:
break
yield chunk
    def run_task_async(self, in_chat: bool = True):
        """Start run_task on the shared pool; chunks are buffered via run_task_cache."""
        self.future = executor.submit(self.run_task_cache, in_chat)
    def run_task_cache(self, in_chat: bool = True):
        """Run the chat task, appending each streamed chunk to the buffer
        concurrently consumed by await_result/pop_chunk."""
        for chunk in self.run_task(in_chat):
            self.chunk_list.append(chunk)
    def run_task(self, in_chat: bool = True):
        """End-to-end chat pipeline: build the prompt, generate SQL, execute it,
        and generate a chart config, streaming progress as a generator.

        When ``in_chat`` is True each yielded item is a server-sent-event frame
        (``data:`` + JSON + blank line); otherwise plain markdown fragments are
        yielded (SQL code block, result table, optional chart image) for the
        non-chat/MCP path. Any exception is persisted via save_error and
        emitted as a final 'error' event instead of being raised.
        """
        try:
            # Enrich the question with terminology hints before building messages.
            if self.ds:
                self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
                                                                            self.ds.oid if isinstance(self.ds,
                                                                                                      CoreDatasource) else 1)
            self.init_messages()
            # return id
            if in_chat:
                yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n'
            # return title
            if self.change_title:
                # NOTE(review): with `or`, this condition is True for any non-empty
                # question and also when the question is empty but non-blank-stripped;
                # presumably `and` was intended — confirm before changing.
                if self.chat_question.question or self.chat_question.question.strip() != '':
                    brief = rename_chat(session=self.session,
                                        rename_object=RenameChat(id=self.get_record().chat_id,
                                                                 brief=self.chat_question.question.strip()[:20]))
                    if in_chat:
                        yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n'
            # select datasource if datasource is none
            if not self.ds:
                ds_res = self.select_datasource()
                for chunk in ds_res:
                    SQLBotLogUtil.info(chunk)
                    if in_chat:
                        yield 'data:' + orjson.dumps(
                            {'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'),
                             'type': 'datasource-result'}).decode() + '\n\n'
                # select_datasource() is expected to have populated self.ds by now.
                if in_chat:
                    yield 'data:' + orjson.dumps({'id': self.ds.id, 'datasource_name': self.ds.name,
                                                  'engine_type': self.ds.type_name or self.ds.type,
                                                  'type': 'datasource'}).decode() + '\n\n'
                # Schema comes from the external assistant datasource when one is
                # configured, otherwise from the internal table metadata.
                self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
                    self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session,
                                                                              current_user=self.current_user,
                                                                              ds=self.ds)
            else:
                self.validate_history_ds()
            # check connection
            connected = check_connection(ds=self.ds, trans=None)
            if not connected:
                raise SQLBotDBConnectionError('Connect DB failed')
            # generate sql
            sql_res = self.generate_sql()
            full_sql_text = ''
            for chunk in sql_res:
                full_sql_text += chunk.get('content')
                if in_chat:
                    yield 'data:' + orjson.dumps(
                        {'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'),
                         'type': 'sql-result'}).decode() + '\n\n'
            if in_chat:
                yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'sql generated'}).decode() + '\n\n'
            # filter sql
            SQLBotLogUtil.info(full_sql_text)
            chart_type = self.get_chart_type_from_sql_answer(full_sql_text)
            # Dynamic-ds assistants inject per-table sub-SQL; type 4 is page-embedded.
            use_dynamic_ds: bool = self.current_assistant and self.current_assistant.type in dynamic_ds_types
            is_page_embedded: bool = self.current_assistant and self.current_assistant.type == 4
            dynamic_sql_result = None
            sqlbot_temp_sql_text = None
            assistant_dynamic_sql = None
            # todo row permission
            # Row-permission filtering applies to normal users (no assistant, or a
            # page-embedded one) and to dynamic-ds assistants.
            if ((not self.current_assistant or is_page_embedded) and is_normal_user(self.current_user)) or use_dynamic_ds:
                sql, tables = self.check_sql(res=full_sql_text)
                sql_result = None
                if use_dynamic_ds:
                    dynamic_sql_result = self.generate_assistant_dynamic_sql(sql, tables)
                    sqlbot_temp_sql_text = dynamic_sql_result.get('sqlbot_temp_sql_text') if dynamic_sql_result else None
                    #sql_result = self.generate_assistant_filter(sql, tables)
                else:
                    sql_result = self.generate_filter(sql, tables)  # maybe no sql and tables
                if sql_result:
                    SQLBotLogUtil.info(sql_result)
                    sql = self.check_save_sql(res=sql_result)
                elif dynamic_sql_result and sqlbot_temp_sql_text:
                    assistant_dynamic_sql = self.check_save_sql(res=sqlbot_temp_sql_text)
                else:
                    sql = self.check_save_sql(res=full_sql_text)
            else:
                sql = self.check_save_sql(res=full_sql_text)
            SQLBotLogUtil.info(sql)
            format_sql = sqlparse.format(sql, reindent=True)
            if in_chat:
                yield 'data:' + orjson.dumps({'content': format_sql, 'type': 'sql'}).decode() + '\n\n'
            else:
                yield f'```sql\n{format_sql}\n```\n\n'
            # execute sql
            real_execute_sql = sql
            if sqlbot_temp_sql_text and assistant_dynamic_sql:
                # Substitute each placeholder table reference with its generated sub-SQL.
                dynamic_sql_result.pop('sqlbot_temp_sql_text')
                for origin_table, subsql in dynamic_sql_result.items():
                    assistant_dynamic_sql = assistant_dynamic_sql.replace(f'{dynamic_subsql_prefix}{origin_table}', subsql)
                real_execute_sql = assistant_dynamic_sql
            result = self.execute_sql(sql=real_execute_sql)
            self.save_sql_data(data_obj=result)
            if in_chat:
                yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n'
            # generate chart
            chart_res = self.generate_chart(chart_type)
            full_chart_text = ''
            for chunk in chart_res:
                full_chart_text += chunk.get('content')
                if in_chat:
                    yield 'data:' + orjson.dumps(
                        {'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'),
                         'type': 'chart-result'}).decode() + '\n\n'
            if in_chat:
                yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'chart generated'}).decode() + '\n\n'
            # filter chart
            SQLBotLogUtil.info(full_chart_text)
            chart = self.check_save_chart(res=full_chart_text)
            SQLBotLogUtil.info(chart)
            if in_chat:
                yield 'data:' + orjson.dumps(
                    {'content': orjson.dumps(chart).decode(), 'type': 'chart'}).decode() + '\n\n'
            else:
                # Non-chat path: render the result set as a markdown table, mapping
                # raw column keys to display names from the chart config.
                data = []
                _fields = {}
                if chart.get('columns'):
                    for _column in chart.get('columns'):
                        if _column:
                            _fields[_column.get('value')] = _column.get('name')
                if chart.get('axis'):
                    if chart.get('axis').get('x'):
                        _fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name')
                    if chart.get('axis').get('y'):
                        _fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name')
                    if chart.get('axis').get('series'):
                        _fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get(
                            'name')
                _fields_list = []
                # The header row is collected only while walking the first data row.
                _fields_skip = False
                for _data in result.get('data'):
                    _row = []
                    for field in result.get('fields'):
                        _row.append(_data.get(field))
                        if not _fields_skip:
                            _fields_list.append(field if not _fields.get(field) else _fields.get(field))
                    data.append(_row)
                    _fields_skip = True
                df = pd.DataFrame(np.array(data), columns=_fields_list)
                markdown_table = df.to_markdown(index=False)
                yield markdown_table + '\n\n'
            record = self.finish()
            if in_chat:
                yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n'
            else:
                # todo generate picture
                if chart['type'] != 'table':
                    yield '### generated chart picture\n\n'
                    image_url = request_picture(self.record.chat_id, self.record.id, chart, result)
                    SQLBotLogUtil.info(image_url)
                    yield f'![{chart["type"]}]({image_url})'
        except Exception as e:
            # Map known error types to structured payloads; anything else carries
            # a short traceback for diagnosis.
            error_msg: str
            if isinstance(e, SingleMessageError):
                error_msg = str(e)
            elif isinstance(e, SQLBotDBConnectionError):
                error_msg = orjson.dumps(
                    {'message': str(e), 'type': 'db-connection-err'}).decode()
            elif isinstance(e, SQLBotDBError):
                error_msg = orjson.dumps(
                    {'message': 'Execute SQL Failed', 'traceback': str(e), 'type': 'exec-sql-err'}).decode()
            else:
                error_msg = orjson.dumps({'message': str(e), 'traceback': traceback.format_exc(limit=1)}).decode()
            self.save_error(message=error_msg)
            if in_chat:
                yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n'
            else:
                yield f'> &#x274c; **ERROR**\n\n> \n\n> {error_msg}'
def run_recommend_questions_task_async(self):
self.future = executor.submit(self.run_recommend_questions_task_cache)
def run_recommend_questions_task_cache(self):
for chunk in self.run_recommend_questions_task():
self.chunk_list.append(chunk)
def run_recommend_questions_task(self):
res = self.generate_recommend_questions_task()
for chunk in res:
if chunk.get('recommended_question'):
yield 'data:' + orjson.dumps(
{'content': chunk.get('recommended_question'), 'type': 'recommended_question'}).decode() + '\n\n'
else:
yield 'data:' + orjson.dumps(
{'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'),
'type': 'recommended_question_result'}).decode() + '\n\n'
def run_analysis_or_predict_task_async(self, action_type: str, base_record: ChatRecord):
self.set_record(save_analysis_predict_record(self.session, base_record, action_type))
self.future = executor.submit(self.run_analysis_or_predict_task_cache, action_type)
def run_analysis_or_predict_task_cache(self, action_type: str):
for chunk in self.run_analysis_or_predict_task(action_type):
self.chunk_list.append(chunk)
def run_analysis_or_predict_task(self, action_type: str):
try:
yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n'
if action_type == 'analysis':
# generate analysis
analysis_res = self.generate_analysis()
for chunk in analysis_res:
yield 'data:' + orjson.dumps(
{'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'),
'type': 'analysis-result'}).decode() + '\n\n'
yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'analysis generated'}).decode() + '\n\n'
yield 'data:' + orjson.dumps({'type': 'analysis_finish'}).decode() + '\n\n'
elif action_type == 'predict':
# generate predict
analysis_res = self.generate_predict()
full_text = ''
for chunk in analysis_res:
yield 'data:' + orjson.dumps(
{'content': chunk.get('content'), 'reasoning_content': chunk.get('reasoning_content'),
'type': 'predict-result'}).decode() + '\n\n'
full_text += chunk.get('content')
yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'predict generated'}).decode() + '\n\n'
_data = self.check_save_predict_data(res=full_text)
if _data:
yield 'data:' + orjson.dumps({'type': 'predict-success'}).decode() + '\n\n'
else:
yield 'data:' + orjson.dumps({'type': 'predict-failed'}).decode() + '\n\n'
yield 'data:' + orjson.dumps({'type': 'predict_finish'}).decode() + '\n\n'
self.finish()
except Exception as e:
error_msg: str
if isinstance(e, SingleMessageError):
error_msg = str(e)
else:
error_msg = orjson.dumps({'message': str(e), 'traceback': traceback.format_exc(limit=1)}).decode()
self.save_error(message=error_msg)
yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n'
finally:
# end
pass
def validate_history_ds(self):
_ds = self.ds
if not self.current_assistant or self.current_assistant.type == 4:
try:
current_ds = self.session.get(CoreDatasource, _ds.id)
if not current_ds:
raise SingleMessageError('chat.ds_is_invalid')
except Exception as e:
raise SingleMessageError("chat.ds_is_invalid")
else:
try:
_ds_list: list[dict] = get_assistant_ds(session=self.session, llm_service=self)
match_ds = any(item.get("id") == _ds.id for item in _ds_list)
if not match_ds:
type = self.current_assistant.type
msg = f"[please check ds list and public ds list]" if type == 0 else f"[please check ds api]"
raise SingleMessageError(msg)
except Exception as e:
raise SingleMessageError(f"ds is invalid [{str(e)}]")
def execute_sql_with_db(db: SQLDatabase, sql: str) -> str:
    """Execute SQL query using SQLDatabase

    Args:
        db: SQLDatabase instance
        sql: SQL query statement

    Returns:
        str: Query results formatted as string

    Raises:
        RuntimeError: wrapping any execution failure, with the original
            exception chained as __cause__ for full context.
    """
    try:
        # Execute query; keep the try body minimal so only db.run is guarded.
        result = db.run(sql)
    except Exception as e:
        error_msg = f"SQL execution failed: {str(e)}"
        SQLBotLogUtil.exception(error_msg)
        # `from e` preserves the driver error as the explicit cause.
        raise RuntimeError(error_msg) from e
    if not result:
        return "Query executed successfully but returned no results."
    # Format results
    return str(result)
def request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
    """Ask the MCP image service to render *chart* and return the expected PNG URL.

    Builds a flat axis descriptor list (plain columns plus x/y/series entries
    from the chart's axis config), posts it together with the row data to
    settings.MCP_IMAGE_HOST, and returns the URL where the rendered image is
    expected to appear.

    NOTE(review): the POST has no timeout and its response status is ignored —
    a slow or failing image service blocks the caller silently; confirm whether
    that is intended.
    """
    file_name = f'c_{chat_id}_r_{record_id}'
    axis = [{'name': col.get('name'), 'value': col.get('value')}
            for col in (chart.get('columns') or [])]
    axis_config = chart.get('axis') or {}
    for axis_role in ('x', 'y', 'series'):
        entry = axis_config.get(axis_role)
        if entry:
            axis.append({'name': entry.get('name'), 'value': entry.get('value'), 'type': axis_role})
    request_obj = {
        "path": os.path.join(settings.MCP_IMAGE_PATH, file_name),
        "type": chart['type'],
        "data": orjson.dumps(data.get('data') or []).decode(),
        "axis": orjson.dumps(axis).decode(),
    }
    requests.post(url=settings.MCP_IMAGE_HOST, json=request_obj)
    return urllib.parse.urljoin(settings.MCP_IMAGE_HOST, f"{file_name}.png")
def get_token_usage(chunk: "BaseMessageChunk", token_usage: Optional[dict] = None):
    """Copy token usage counters from *chunk* into *token_usage* (mutated in place).

    Args:
        chunk: streamed model chunk; usage counters are read from its
            ``usage_metadata``, when present.
        token_usage: dict to receive 'input_tokens', 'output_tokens' and
            'total_tokens'. The old ``token_usage: dict = {}`` default was a
            shared mutable default (one dict reused across all calls); it is
            now a per-call fresh dict when omitted.
    """
    if token_usage is None:
        token_usage = {}
    try:
        if chunk.usage_metadata:
            token_usage['input_tokens'] = chunk.usage_metadata.get('input_tokens')
            token_usage['output_tokens'] = chunk.usage_metadata.get('output_tokens')
            token_usage['total_tokens'] = chunk.usage_metadata.get('total_tokens')
    except Exception:
        # Best effort: not every chunk type carries usage_metadata.
        pass
def get_lang_name(lang: str):
    """Map a language code to its Chinese display name: 'en' → '英文', anything else → '简体中文'."""
    # `lang == 'en'` already implies lang is truthy, so the old `lang and`
    # guard was redundant; None/'' still fall through to the default.
    return '英文' if lang == 'en' else '简体中文'