diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py
new file mode 100644
index 0000000..eb41825
--- /dev/null
+++ b/backend/apps/chat/models/chat_model.py
@@ -0,0 +1,258 @@
+from datetime import datetime
+from enum import Enum
+from typing import List, Optional
+
+from fastapi import Body
+from pydantic import BaseModel
+from sqlalchemy import Column, Integer, Text, BigInteger, DateTime, Identity, Boolean
+from sqlalchemy import Enum as SQLAlchemyEnum
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlmodel import SQLModel, Field
+
+from apps.template.filter.generator import get_permissions_template
+from apps.template.generate_analysis.generator import get_analysis_template
+from apps.template.generate_chart.generator import get_chart_template
+from apps.template.generate_dynamic.generator import get_dynamic_template
+from apps.template.generate_guess_question.generator import get_guess_question_template
+from apps.template.generate_predict.generator import get_predict_template
+from apps.template.generate_sql.generator import get_sql_template
+from apps.template.select_datasource.generator import get_datasource_template
+
+
+def enum_values(enum_class: type[Enum]) -> list:
+    """Return the values of an enum as a list."""
+    return [status.value for status in enum_class]
+
+
+class TypeEnum(Enum):
+    CHAT = "0"
+
+
+# TODO other usage
+
+class OperationEnum(Enum):
+    GENERATE_SQL = '0'
+    GENERATE_CHART = '1'
+    ANALYSIS = '2'
+    PREDICT_DATA = '3'
+    GENERATE_RECOMMENDED_QUESTIONS = '4'
+    GENERATE_SQL_WITH_PERMISSIONS = '5'
+    CHOOSE_DATASOURCE = '6'
+    GENERATE_DYNAMIC_SQL = '7'
+
+
+# TODO choose table / check connection / generate description
+
+class ChatLog(SQLModel, table=True):
+    __tablename__ = "chat_log"
+    id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
+    type: TypeEnum = Field(
+        sa_column=Column(SQLAlchemyEnum(TypeEnum, native_enum=False, values_callable=enum_values, length=3)))
+    operate: OperationEnum = Field(
+        sa_column=Column(SQLAlchemyEnum(OperationEnum, native_enum=False, values_callable=enum_values, length=3)))
+    pid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
+    ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger))
+    base_modal: Optional[str] = Field(max_length=255)
+    messages: Optional[list[dict]] = Field(sa_column=Column(JSONB))
+    reasoning_content: Optional[str] = Field(sa_column=Column(Text, nullable=True))
+    start_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
+    finish_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
+    token_usage: Optional[dict | int] = Field(sa_column=Column(JSONB))
+
+
+class Chat(SQLModel, table=True):
+    __tablename__ = "chat"
+    id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
+    oid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True, default=1))
+    create_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
+    create_by: int = Field(sa_column=Column(BigInteger, nullable=True))
+    brief: str = Field(max_length=64, nullable=True)
+    chat_type: str = Field(max_length=20, default="chat")  # chat, datasource
+    datasource: int = Field(sa_column=Column(BigInteger, nullable=True))
+    engine_type: str = Field(max_length=64)
+    origin: Optional[int] = Field(
+        sa_column=Column(Integer, nullable=False, default=0))  # 0: default, 1: mcp, 2: assistant
+
+
+class ChatRecord(SQLModel, table=True):
+    __tablename__ = "chat_record"
+    id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
+    chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
+    ai_modal_id: Optional[int] = Field(sa_column=Column(BigInteger))
+    first_chat: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
+    create_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
+    finish_time: datetime = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
+    create_by: int = Field(sa_column=Column(BigInteger, nullable=True))
+    datasource: int = Field(sa_column=Column(BigInteger, nullable=True))
+    engine_type: str = Field(max_length=64, nullable=True)
+    question: str = Field(sa_column=Column(Text, nullable=True))
+    sql_answer: str = Field(sa_column=Column(Text, nullable=True))
+    sql: str = Field(sa_column=Column(Text, nullable=True))
+    sql_exec_result: str = Field(sa_column=Column(Text, nullable=True))
+    data: str = Field(sa_column=Column(Text, nullable=True))
+    chart_answer: str = Field(sa_column=Column(Text, nullable=True))
+    chart: str = Field(sa_column=Column(Text, nullable=True))
+    analysis: str = Field(sa_column=Column(Text, nullable=True))
+    predict: str = Field(sa_column=Column(Text, nullable=True))
+    predict_data: str = Field(sa_column=Column(Text, nullable=True))
+    recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True))
+    recommended_question: str = Field(sa_column=Column(Text, nullable=True))
+    datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True))
+    finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
+    error: str = Field(sa_column=Column(Text, nullable=True))
+    analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
+    predict_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
+
+
+class ChatRecordResult(BaseModel):
+    id: Optional[int] = None
+    chat_id: Optional[int] = None
+    ai_modal_id: Optional[int] = None
+    first_chat: bool = False
+    create_time: Optional[datetime] = None
+    finish_time: Optional[datetime] = None
+    question: Optional[str] = None
+    sql_answer: Optional[str] = None
+    sql: Optional[str] = None
+    data: Optional[str] = None
+    chart_answer: Optional[str] = None
+    chart: Optional[str] = None
+    analysis: Optional[str] = None
+    predict: Optional[str] = None
+    predict_data: Optional[str] = None
+    recommended_question: Optional[str] = None
+    datasource_select_answer: Optional[str] = None
+    finish: Optional[bool] = None
+    error: Optional[str] = None
+    analysis_record_id: Optional[int] = None
+    predict_record_id: Optional[int] = None
+    sql_reasoning_content: Optional[str] = None
+    chart_reasoning_content: Optional[str] = None
+    analysis_reasoning_content: Optional[str] = None
+    predict_reasoning_content: Optional[str] = None
+
+
+class CreateChat(BaseModel):
+    id: Optional[int] = None
+    question: Optional[str] = None
+    datasource: Optional[int] = None
+    origin: Optional[int] = 0
+
+
+class RenameChat(BaseModel):
+    id: Optional[int] = None
+    brief: str = ''
+
+
+class ChatInfo(BaseModel):
+    id: Optional[int] = None
+    create_time: Optional[datetime] = None
+    create_by: Optional[int] = None
+    brief: str = ''
+    chat_type: str = "chat"
+    datasource: Optional[int] = None
+    engine_type: str = ''
+    ds_type: str = ''
+    datasource_name: str = ''
+    datasource_exists: bool = True
+    records: List[ChatRecord | dict] = []
+
+
+class AiModelQuestion(BaseModel):
+    question: Optional[str] = None
+    ai_modal_id: Optional[int] = None
+    ai_modal_name: Optional[str] = None  # Specific model name
+    engine: str = ""
+    db_schema: str = ""
+    sql: str = ""
+    rule: str = ""
+    fields: str = ""
+    data: str = ""
+    lang: str = "简体中文"
+    filter: list = []
+    sub_query: Optional[list[dict]] = None
+    terminologies: str = ""
+    error_msg: str = ""
+
+    def sql_sys_question(self):
+        return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question,
+                                                   lang=self.lang, terminologies=self.terminologies)
+
+    def sql_user_question(self, current_time: str):
+        return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
+                                                 rule=self.rule, current_time=current_time, error_msg=self.error_msg)
+
+    def chart_sys_question(self):
+        return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)
+
+    def chart_user_question(self, chart_type: Optional[str] = None):
+        return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule,
+                                                   chart_type=chart_type)
+
+    def analysis_sys_question(self):
+        return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies)
+
+    def analysis_user_question(self):
+        return get_analysis_template()['user'].format(fields=self.fields, data=self.data)
+
+    def predict_sys_question(self):
+        return get_predict_template()['system'].format(lang=self.lang)
+
+    def predict_user_question(self):
+        return get_predict_template()['user'].format(fields=self.fields, data=self.data)
+
+    def datasource_sys_question(self):
+        return get_datasource_template()['system'].format(lang=self.lang)
+
+    def datasource_user_question(self, datasource_list: str = "[]"):
+        return get_datasource_template()['user'].format(question=self.question, data=datasource_list)
+
+    def guess_sys_question(self):
+        return get_guess_question_template()['system'].format(lang=self.lang)
+
+    def guess_user_question(self, old_questions: str = "[]"):
+        return get_guess_question_template()['user'].format(question=self.question, schema=self.db_schema,
+                                                            old_questions=old_questions)
+
+    def filter_sys_question(self):
+        return get_permissions_template()['system'].format(lang=self.lang, engine=self.engine)
+
+    def filter_user_question(self):
+        return get_permissions_template()['user'].format(sql=self.sql, filter=self.filter)
+
+    def dynamic_sys_question(self):
+        return get_dynamic_template()['system'].format(lang=self.lang, engine=self.engine)
+
+    def dynamic_user_question(self):
+        return get_dynamic_template()['user'].format(sql=self.sql, sub_query=self.sub_query)
+
+
+class ChatQuestion(AiModelQuestion):
+    chat_id: int
+
+
+class ChatMcp(ChatQuestion):
+    token: str
+
+
+class ChatStart(BaseModel):
+    username: str = Body(description='username')
+    password: str = Body(description='password')
+
+
+class McpQuestion(BaseModel):
+    question: str = Body(description='user question')
+    chat_id: int = Body(description='chat ID')
+    token: str = Body(description='token')
+
+
+class AxisObj(BaseModel):
+    name: str = ''
+    value: str = ''
+    type: str | None = None
+
+
+class ExcelData(BaseModel):
+    axis: list[AxisObj] = []
+    data: list[dict] = []
+    name: str = 'Excel'
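
For reviewers unfamiliar with the template generators, the sketch below shows how the prompt-builder methods on `AiModelQuestion` / `ChatQuestion` are expected to be called. The concrete values (model name, engine, schema string) and the message-dict shape handed to the LLM are illustrative assumptions, not something this diff defines:

```python
from datetime import datetime

from apps.chat.models.chat_model import ChatQuestion

# Illustrative values only: in the backend the engine/schema strings come from
# the selected datasource and the model name from the configured AI model.
q = ChatQuestion(
    chat_id=1,
    question="How many orders were placed last month?",
    ai_modal_id=1,
    ai_modal_name="example-llm",  # assumed model name
    engine="postgresql",
    db_schema="orders(id bigint, amount numeric, created_at timestamp)",
)

# A system/user message pair for SQL generation; the dict shape is an assumption,
# since the diff only defines the prompt-text builders themselves.
messages = [
    {"role": "system", "content": q.sql_sys_question()},
    {"role": "user", "content": q.sql_user_question(current_time=datetime.now().isoformat())},
]
```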
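
Similarly, a minimal sketch of writing to the new tables, assuming a local Postgres instance; the connection string and sample values are placeholders, and the real engine/session wiring lives elsewhere in the backend:

```python
from datetime import datetime

from sqlmodel import Session, create_engine

from apps.chat.models.chat_model import Chat, ChatLog, OperationEnum, TypeEnum

# Placeholder DSN; the application configures its own engine and sessions.
engine = create_engine("postgresql+psycopg2://user:pass@localhost/sqlbot")

with Session(engine) as session:
    # A new conversation; the id is generated by the identity column.
    chat = Chat(create_time=datetime.now(), create_by=1, brief="orders", engine_type="postgresql")
    session.add(chat)
    session.commit()
    session.refresh(chat)

    # One LLM call recorded in the chat log.
    log = ChatLog(
        type=TypeEnum.CHAT,
        operate=OperationEnum.GENERATE_SQL,
        ai_modal_id=1,
        messages=[{"role": "user", "content": "How many orders were placed last month?"}],
        start_time=datetime.now(),
    )
    session.add(log)
    session.commit()
```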