Add File
This commit is contained in:
258
backend/apps/chat/models/chat_model.py
Normal file
258
backend/apps/chat/models/chat_model.py
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from fastapi import Body
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import Column, Integer, Text, BigInteger, DateTime, Identity, Boolean
|
||||||
|
from sqlalchemy import Enum as SQLAlchemyEnum
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlmodel import SQLModel, Field
|
||||||
|
|
||||||
|
from apps.template.filter.generator import get_permissions_template
|
||||||
|
from apps.template.generate_analysis.generator import get_analysis_template
|
||||||
|
from apps.template.generate_chart.generator import get_chart_template
|
||||||
|
from apps.template.generate_dynamic.generator import get_dynamic_template
|
||||||
|
from apps.template.generate_guess_question.generator import get_guess_question_template
|
||||||
|
from apps.template.generate_predict.generator import get_predict_template
|
||||||
|
from apps.template.generate_sql.generator import get_sql_template
|
||||||
|
from apps.template.select_datasource.generator import get_datasource_template
|
||||||
|
|
||||||
|
|
||||||
|
def enum_values(enum_class: type[Enum]) -> list:
|
||||||
|
"""Get values for enum."""
|
||||||
|
return [status.value for status in enum_class]
|
||||||
|
|
||||||
|
|
||||||
|
class TypeEnum(Enum):
|
||||||
|
CHAT = "0"
|
||||||
|
|
||||||
|
|
||||||
|
# TODO other usage
|
||||||
|
|
||||||
|
class OperationEnum(Enum):
|
||||||
|
GENERATE_SQL = '0'
|
||||||
|
GENERATE_CHART = '1'
|
||||||
|
ANALYSIS = '2'
|
||||||
|
PREDICT_DATA = '3'
|
||||||
|
GENERATE_RECOMMENDED_QUESTIONS = '4'
|
||||||
|
GENERATE_SQL_WITH_PERMISSIONS = '5'
|
||||||
|
CHOOSE_DATASOURCE = '6'
|
||||||
|
GENERATE_DYNAMIC_SQL = '7'
|
||||||
|
|
||||||
|
|
||||||
|
# TODO choose table / check connection / generate description
|
||||||
|
|
||||||
|
class ChatLog(SQLModel, table=True):
|
||||||
|
__tablename__ = "chat_log"
|
||||||
|
id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
|
||||||
|
type: TypeEnum = Field(
|
||||||
|
sa_column=Column(SQLAlchemyEnum(TypeEnum, native_enum=False, values_callable=enum_values, length=3)))
|
||||||
|
operate: OperationEnum = Field(
|
||||||
|
sa_column=Column(SQLAlchemyEnum(OperationEnum, native_enum=False, values_callable=enum_values, length=3)))
|
||||||
|
pid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
|
||||||
|
ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger))
|
||||||
|
base_modal: Optional[str] = Field(max_length=255)
|
||||||
|
messages: Optional[list[dict]] = Field(sa_column=Column(JSONB))
|
||||||
|
reasoning_content: Optional[str | None] = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
start_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
|
||||||
|
finish_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
|
||||||
|
token_usage: Optional[dict | None | int] = Field(sa_column=Column(JSONB))
|
||||||
|
|
||||||
|
|
||||||
|
class Chat(SQLModel, table=True):
|
||||||
|
__tablename__ = "chat"
|
||||||
|
id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
|
||||||
|
oid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True, default=1))
|
||||||
|
create_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
|
||||||
|
create_by: int = Field(sa_column=Column(BigInteger, nullable=True))
|
||||||
|
brief: str = Field(max_length=64, nullable=True)
|
||||||
|
chat_type: str = Field(max_length=20, default="chat") # chat, datasource
|
||||||
|
datasource: int = Field(sa_column=Column(BigInteger, nullable=True))
|
||||||
|
engine_type: str = Field(max_length=64)
|
||||||
|
origin: Optional[int] = Field(
|
||||||
|
sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRecord(SQLModel, table=True):
|
||||||
|
__tablename__ = "chat_record"
|
||||||
|
id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
|
||||||
|
chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
|
||||||
|
ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger))
|
||||||
|
first_chat: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
|
||||||
|
create_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
|
||||||
|
finish_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
|
||||||
|
create_by: int = Field(sa_column=Column(BigInteger, nullable=True))
|
||||||
|
datasource: int = Field(sa_column=Column(BigInteger, nullable=True))
|
||||||
|
engine_type: str = Field(max_length=64, nullable=True)
|
||||||
|
question: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
sql_answer: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
sql: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
sql_exec_result: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
data: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
chart_answer: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
chart: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
analysis: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
predict: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
predict_data: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
recommended_question: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
|
||||||
|
error: str = Field(sa_column=Column(Text, nullable=True))
|
||||||
|
analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
|
||||||
|
predict_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRecordResult(BaseModel):
|
||||||
|
id: Optional[int] = None
|
||||||
|
chat_id: Optional[int] = None
|
||||||
|
ai_modal_id: Optional[int] = None
|
||||||
|
first_chat: bool = False
|
||||||
|
create_time: Optional[datetime] = None
|
||||||
|
finish_time: Optional[datetime] = None
|
||||||
|
question: Optional[str] = None
|
||||||
|
sql_answer: Optional[str] = None
|
||||||
|
sql: Optional[str] = None
|
||||||
|
data: Optional[str] = None
|
||||||
|
chart_answer: Optional[str] = None
|
||||||
|
chart: Optional[str] = None
|
||||||
|
analysis: Optional[str] = None
|
||||||
|
predict: Optional[str] = None
|
||||||
|
predict_data: Optional[str] = None
|
||||||
|
recommended_question: Optional[str] = None
|
||||||
|
datasource_select_answer: Optional[str] = None
|
||||||
|
finish: Optional[bool] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
analysis_record_id: Optional[int] = None
|
||||||
|
predict_record_id: Optional[int] = None
|
||||||
|
sql_reasoning_content: Optional[str] = None
|
||||||
|
chart_reasoning_content: Optional[str] = None
|
||||||
|
analysis_reasoning_content: Optional[str] = None
|
||||||
|
predict_reasoning_content: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CreateChat(BaseModel):
|
||||||
|
id: int = None
|
||||||
|
question: str = None
|
||||||
|
datasource: int = None
|
||||||
|
origin: Optional[int] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class RenameChat(BaseModel):
|
||||||
|
id: int = None
|
||||||
|
brief: str = ''
|
||||||
|
|
||||||
|
|
||||||
|
class ChatInfo(BaseModel):
|
||||||
|
id: Optional[int] = None
|
||||||
|
create_time: datetime = None
|
||||||
|
create_by: int = None
|
||||||
|
brief: str = ''
|
||||||
|
chat_type: str = "chat"
|
||||||
|
datasource: Optional[int] = None
|
||||||
|
engine_type: str = ''
|
||||||
|
ds_type: str = ''
|
||||||
|
datasource_name: str = ''
|
||||||
|
datasource_exists: bool = True
|
||||||
|
records: List[ChatRecord | dict] = []
|
||||||
|
|
||||||
|
|
||||||
|
class AiModelQuestion(BaseModel):
|
||||||
|
question: str = None
|
||||||
|
ai_modal_id: int = None
|
||||||
|
ai_modal_name: str = None # Specific model name
|
||||||
|
engine: str = ""
|
||||||
|
db_schema: str = ""
|
||||||
|
sql: str = ""
|
||||||
|
rule: str = ""
|
||||||
|
fields: str = ""
|
||||||
|
data: str = ""
|
||||||
|
lang: str = "简体中文"
|
||||||
|
filter: str = []
|
||||||
|
sub_query: Optional[list[dict]] = None
|
||||||
|
terminologies: str = ""
|
||||||
|
error_msg: str = ""
|
||||||
|
|
||||||
|
def sql_sys_question(self):
|
||||||
|
return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question,
|
||||||
|
lang=self.lang, terminologies=self.terminologies)
|
||||||
|
|
||||||
|
def sql_user_question(self, current_time: str):
|
||||||
|
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
|
||||||
|
rule=self.rule, current_time=current_time, error_msg=self.error_msg)
|
||||||
|
|
||||||
|
def chart_sys_question(self):
|
||||||
|
return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)
|
||||||
|
|
||||||
|
def chart_user_question(self, chart_type: Optional[str] = None):
|
||||||
|
return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule,
|
||||||
|
chart_type=chart_type)
|
||||||
|
|
||||||
|
def analysis_sys_question(self):
|
||||||
|
return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies)
|
||||||
|
|
||||||
|
def analysis_user_question(self):
|
||||||
|
return get_analysis_template()['user'].format(fields=self.fields, data=self.data)
|
||||||
|
|
||||||
|
def predict_sys_question(self):
|
||||||
|
return get_predict_template()['system'].format(lang=self.lang)
|
||||||
|
|
||||||
|
def predict_user_question(self):
|
||||||
|
return get_predict_template()['user'].format(fields=self.fields, data=self.data)
|
||||||
|
|
||||||
|
def datasource_sys_question(self):
|
||||||
|
return get_datasource_template()['system'].format(lang=self.lang)
|
||||||
|
|
||||||
|
def datasource_user_question(self, datasource_list: str = "[]"):
|
||||||
|
return get_datasource_template()['user'].format(question=self.question, data=datasource_list)
|
||||||
|
|
||||||
|
def guess_sys_question(self):
|
||||||
|
return get_guess_question_template()['system'].format(lang=self.lang)
|
||||||
|
|
||||||
|
def guess_user_question(self, old_questions: str = "[]"):
|
||||||
|
return get_guess_question_template()['user'].format(question=self.question, schema=self.db_schema,
|
||||||
|
old_questions=old_questions)
|
||||||
|
|
||||||
|
def filter_sys_question(self):
|
||||||
|
return get_permissions_template()['system'].format(lang=self.lang, engine=self.engine)
|
||||||
|
|
||||||
|
def filter_user_question(self):
|
||||||
|
return get_permissions_template()['user'].format(sql=self.sql, filter=self.filter)
|
||||||
|
|
||||||
|
def dynamic_sys_question(self):
|
||||||
|
return get_dynamic_template()['system'].format(lang=self.lang, engine=self.engine)
|
||||||
|
|
||||||
|
def dynamic_user_question(self):
|
||||||
|
return get_dynamic_template()['user'].format(sql=self.sql, sub_query=self.sub_query)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatQuestion(AiModelQuestion):
|
||||||
|
chat_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMcp(ChatQuestion):
|
||||||
|
token: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatStart(BaseModel):
|
||||||
|
username: str = Body(description='用户名')
|
||||||
|
password: str = Body(description='密码')
|
||||||
|
|
||||||
|
|
||||||
|
class McpQuestion(BaseModel):
|
||||||
|
question: str = Body(description='用户提问')
|
||||||
|
chat_id: int = Body(description='会话ID')
|
||||||
|
token: str = Body(description='token')
|
||||||
|
|
||||||
|
|
||||||
|
class AxisObj(BaseModel):
|
||||||
|
name: str = ''
|
||||||
|
value: str = ''
|
||||||
|
type: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExcelData(BaseModel):
|
||||||
|
axis: list[AxisObj] = []
|
||||||
|
data: list[dict] = []
|
||||||
|
name: str = 'Excel'
|
||||||
Reference in New Issue
Block a user