This commit is contained in:
2025-09-08 16:36:00 +08:00
parent dda4bf445c
commit bd3adc2556

View File

@@ -0,0 +1,258 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional
from fastapi import Body
from pydantic import BaseModel
from sqlalchemy import Column, Integer, Text, BigInteger, DateTime, Identity, Boolean
from sqlalchemy import Enum as SQLAlchemyEnum
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import SQLModel, Field
from apps.template.filter.generator import get_permissions_template
from apps.template.generate_analysis.generator import get_analysis_template
from apps.template.generate_chart.generator import get_chart_template
from apps.template.generate_dynamic.generator import get_dynamic_template
from apps.template.generate_guess_question.generator import get_guess_question_template
from apps.template.generate_predict.generator import get_predict_template
from apps.template.generate_sql.generator import get_sql_template
from apps.template.select_datasource.generator import get_datasource_template
def enum_values(enum_class: type[Enum]) -> list:
"""Get values for enum."""
return [status.value for status in enum_class]
class TypeEnum(Enum):
CHAT = "0"
# TODO other usage
class OperationEnum(Enum):
GENERATE_SQL = '0'
GENERATE_CHART = '1'
ANALYSIS = '2'
PREDICT_DATA = '3'
GENERATE_RECOMMENDED_QUESTIONS = '4'
GENERATE_SQL_WITH_PERMISSIONS = '5'
CHOOSE_DATASOURCE = '6'
GENERATE_DYNAMIC_SQL = '7'
# TODO choose table / check connection / generate description
class ChatLog(SQLModel, table=True):
__tablename__ = "chat_log"
id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
type: TypeEnum = Field(
sa_column=Column(SQLAlchemyEnum(TypeEnum, native_enum=False, values_callable=enum_values, length=3)))
operate: OperationEnum = Field(
sa_column=Column(SQLAlchemyEnum(OperationEnum, native_enum=False, values_callable=enum_values, length=3)))
pid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger))
base_modal: Optional[str] = Field(max_length=255)
messages: Optional[list[dict]] = Field(sa_column=Column(JSONB))
reasoning_content: Optional[str | None] = Field(sa_column=Column(Text, nullable=True))
start_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
finish_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
token_usage: Optional[dict | None | int] = Field(sa_column=Column(JSONB))
class Chat(SQLModel, table=True):
__tablename__ = "chat"
id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
oid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True, default=1))
create_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
create_by: int = Field(sa_column=Column(BigInteger, nullable=True))
brief: str = Field(max_length=64, nullable=True)
chat_type: str = Field(max_length=20, default="chat") # chat, datasource
datasource: int = Field(sa_column=Column(BigInteger, nullable=True))
engine_type: str = Field(max_length=64)
origin: Optional[int] = Field(
sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant
class ChatRecord(SQLModel, table=True):
__tablename__ = "chat_record"
id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger))
first_chat: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
create_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
finish_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
create_by: int = Field(sa_column=Column(BigInteger, nullable=True))
datasource: int = Field(sa_column=Column(BigInteger, nullable=True))
engine_type: str = Field(max_length=64, nullable=True)
question: str = Field(sa_column=Column(Text, nullable=True))
sql_answer: str = Field(sa_column=Column(Text, nullable=True))
sql: str = Field(sa_column=Column(Text, nullable=True))
sql_exec_result: str = Field(sa_column=Column(Text, nullable=True))
data: str = Field(sa_column=Column(Text, nullable=True))
chart_answer: str = Field(sa_column=Column(Text, nullable=True))
chart: str = Field(sa_column=Column(Text, nullable=True))
analysis: str = Field(sa_column=Column(Text, nullable=True))
predict: str = Field(sa_column=Column(Text, nullable=True))
predict_data: str = Field(sa_column=Column(Text, nullable=True))
recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True))
recommended_question: str = Field(sa_column=Column(Text, nullable=True))
datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True))
finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
error: str = Field(sa_column=Column(Text, nullable=True))
analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
predict_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
class ChatRecordResult(BaseModel):
id: Optional[int] = None
chat_id: Optional[int] = None
ai_modal_id: Optional[int] = None
first_chat: bool = False
create_time: Optional[datetime] = None
finish_time: Optional[datetime] = None
question: Optional[str] = None
sql_answer: Optional[str] = None
sql: Optional[str] = None
data: Optional[str] = None
chart_answer: Optional[str] = None
chart: Optional[str] = None
analysis: Optional[str] = None
predict: Optional[str] = None
predict_data: Optional[str] = None
recommended_question: Optional[str] = None
datasource_select_answer: Optional[str] = None
finish: Optional[bool] = None
error: Optional[str] = None
analysis_record_id: Optional[int] = None
predict_record_id: Optional[int] = None
sql_reasoning_content: Optional[str] = None
chart_reasoning_content: Optional[str] = None
analysis_reasoning_content: Optional[str] = None
predict_reasoning_content: Optional[str] = None
class CreateChat(BaseModel):
id: int = None
question: str = None
datasource: int = None
origin: Optional[int] = 0
class RenameChat(BaseModel):
id: int = None
brief: str = ''
class ChatInfo(BaseModel):
id: Optional[int] = None
create_time: datetime = None
create_by: int = None
brief: str = ''
chat_type: str = "chat"
datasource: Optional[int] = None
engine_type: str = ''
ds_type: str = ''
datasource_name: str = ''
datasource_exists: bool = True
records: List[ChatRecord | dict] = []
class AiModelQuestion(BaseModel):
question: str = None
ai_modal_id: int = None
ai_modal_name: str = None # Specific model name
engine: str = ""
db_schema: str = ""
sql: str = ""
rule: str = ""
fields: str = ""
data: str = ""
lang: str = "简体中文"
filter: str = []
sub_query: Optional[list[dict]] = None
terminologies: str = ""
error_msg: str = ""
def sql_sys_question(self):
return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question,
lang=self.lang, terminologies=self.terminologies)
def sql_user_question(self, current_time: str):
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
rule=self.rule, current_time=current_time, error_msg=self.error_msg)
def chart_sys_question(self):
return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)
def chart_user_question(self, chart_type: Optional[str] = None):
return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule,
chart_type=chart_type)
def analysis_sys_question(self):
return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies)
def analysis_user_question(self):
return get_analysis_template()['user'].format(fields=self.fields, data=self.data)
def predict_sys_question(self):
return get_predict_template()['system'].format(lang=self.lang)
def predict_user_question(self):
return get_predict_template()['user'].format(fields=self.fields, data=self.data)
def datasource_sys_question(self):
return get_datasource_template()['system'].format(lang=self.lang)
def datasource_user_question(self, datasource_list: str = "[]"):
return get_datasource_template()['user'].format(question=self.question, data=datasource_list)
def guess_sys_question(self):
return get_guess_question_template()['system'].format(lang=self.lang)
def guess_user_question(self, old_questions: str = "[]"):
return get_guess_question_template()['user'].format(question=self.question, schema=self.db_schema,
old_questions=old_questions)
def filter_sys_question(self):
return get_permissions_template()['system'].format(lang=self.lang, engine=self.engine)
def filter_user_question(self):
return get_permissions_template()['user'].format(sql=self.sql, filter=self.filter)
def dynamic_sys_question(self):
return get_dynamic_template()['system'].format(lang=self.lang, engine=self.engine)
def dynamic_user_question(self):
return get_dynamic_template()['user'].format(sql=self.sql, sub_query=self.sub_query)
class ChatQuestion(AiModelQuestion):
chat_id: int
class ChatMcp(ChatQuestion):
token: str
class ChatStart(BaseModel):
username: str = Body(description='用户名')
password: str = Body(description='密码')
class McpQuestion(BaseModel):
question: str = Body(description='用户提问')
chat_id: int = Body(description='会话ID')
token: str = Body(description='token')
class AxisObj(BaseModel):
name: str = ''
value: str = ''
type: str | None = None
class ExcelData(BaseModel):
axis: list[AxisObj] = []
data: list[dict] = []
name: str = 'Excel'