This commit is contained in:
2025-09-08 16:36:11 +08:00
parent 69b9006534
commit 6e8eb7bbc0

View File

@@ -0,0 +1,190 @@
from datetime import timedelta
import json
import os
from typing import List, Optional
from fastapi import APIRouter, Form, HTTPException, Query, Request, Response, UploadFile
from fastapi.responses import StreamingResponse
from sqlmodel import select
from apps.system.crud.assistant import get_assistant_info
from apps.system.crud.assistant_manage import dynamic_upgrade_cors, save
from apps.system.models.system_model import AssistantModel
from apps.system.schemas.auth import CacheName, CacheNamespace
from apps.system.schemas.system_schema import AssistantBase, AssistantDTO, AssistantUiSchema, AssistantValidator
from common.core.deps import SessionDep, Trans
from common.core.security import create_access_token
from common.core.sqlbot_cache import clear_cache
from common.utils.time import get_timestamp
from common.core.config import settings
from common.utils.utils import get_origin_from_referer
from sqlbot_xpack.file_utils import SQLBotFileUtils
router = APIRouter(tags=["system/assistant"], prefix="/system/assistant")
@router.get("/info/{id}")
async def info(request: Request, response: Response, session: SessionDep, trans: Trans, id: int) -> AssistantModel:
if not id:
raise Exception('miss assistant id')
db_model = await get_assistant_info(session=session, assistant_id=id)
if not db_model:
raise RuntimeError(f"assistant application not exist")
db_model = AssistantModel.model_validate(db_model)
response.headers["Access-Control-Allow-Origin"] = db_model.domain
origin = request.headers.get("origin") or get_origin_from_referer(request)
if not origin:
raise RuntimeError(trans('i18n_embedded.invalid_origin', origin = origin or ''))
origin = origin.rstrip('/')
if origin != db_model.domain:
raise RuntimeError(trans('i18n_embedded.invalid_origin', origin = origin or ''))
return db_model
@router.get("/app/{appId}")
async def getApp(request: Request, response: Response, session: SessionDep, trans: Trans, appId: str) -> AssistantModel:
if not appId:
raise Exception('miss assistant appId')
db_model = session.exec(select(AssistantModel).where(AssistantModel.app_id == appId)).first()
if not db_model:
raise RuntimeError(f"assistant application not exist")
db_model = AssistantModel.model_validate(db_model)
response.headers["Access-Control-Allow-Origin"] = db_model.domain
origin = request.headers.get("origin") or get_origin_from_referer(request)
if not origin:
raise RuntimeError(trans('i18n_embedded.invalid_origin', origin = origin or ''))
origin = origin.rstrip('/')
if origin != db_model.domain:
raise RuntimeError(trans('i18n_embedded.invalid_origin', origin = origin or ''))
return db_model
@router.get("/validator", response_model=AssistantValidator)
async def validator(session: SessionDep, id: int, virtual: Optional[int] = Query(None)):
if not id:
raise Exception('miss assistant id')
db_model = await get_assistant_info(session=session, assistant_id=id)
if not db_model:
return AssistantValidator()
db_model = AssistantModel.model_validate(db_model)
assistant_oid = 1
if(db_model.type == 0):
configuration = db_model.configuration
config_obj = json.loads(configuration) if configuration else {}
assistant_oid = config_obj.get('oid', 1)
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
assistantDict = {
"id": virtual, "account": 'sqlbot-inner-assistant', "oid": assistant_oid, "assistant_id": id
}
access_token = create_access_token(
assistantDict, expires_delta=access_token_expires
)
return AssistantValidator(True, True, True, access_token)
@router.get('/picture/{file_id}')
async def picture(file_id: str):
file_path = SQLBotFileUtils.get_file_path(file_id=file_id)
if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail="File not found")
if file_id.lower().endswith(".svg"):
media_type = "image/svg+xml"
else:
media_type = "image/jpeg"
def iterfile():
with open(file_path, mode="rb") as f:
yield from f
return StreamingResponse(iterfile(), media_type=media_type)
@router.patch('/ui')
async def ui(session: SessionDep, data: str = Form(), files: List[UploadFile] = []):
json_data = json.loads(data)
uiSchema = AssistantUiSchema(**json_data)
id = uiSchema.id
db_model = session.get(AssistantModel, id)
if not db_model:
raise ValueError(f"AssistantModel with id {id} not found")
configuration = db_model.configuration
config_obj = json.loads(configuration) if configuration else {}
ui_schema_dict = uiSchema.model_dump(exclude_none=True, exclude_unset=True)
if files:
for file in files:
origin_file_name = file.filename
file_name, flag_name = SQLBotFileUtils.split_filename_and_flag(origin_file_name)
file.filename = file_name
if flag_name == 'logo' or flag_name == 'float_icon':
SQLBotFileUtils.check_file(file=file, file_types=[".jpg", ".jpeg", ".png", ".svg"], limit_file_size=(10 * 1024 * 1024))
if config_obj.get(flag_name):
SQLBotFileUtils.detete_file(config_obj.get(flag_name))
file_id = await SQLBotFileUtils.upload(file)
ui_schema_dict[flag_name] = file_id
else:
raise ValueError(f"Unsupported file flag: {flag_name}")
for flag_name in ['logo', 'float_icon']:
file_val = config_obj.get(flag_name)
if file_val and not ui_schema_dict.get(flag_name):
config_obj[flag_name] = None
SQLBotFileUtils.detete_file(file_val)
for attr, value in ui_schema_dict.items():
if attr != 'id' and not attr.startswith("__"):
config_obj[attr] = value
db_model.configuration = json.dumps(config_obj, ensure_ascii=False)
session.add(db_model)
session.commit()
await clear_ui_cache(db_model.id)
@clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="id")
async def clear_ui_cache(id: int):
pass
@router.get("", response_model=list[AssistantModel])
async def query(session: SessionDep):
list_result = session.exec(select(AssistantModel).where(AssistantModel.type != 4).order_by(AssistantModel.name, AssistantModel.create_time)).all()
return list_result
@router.post("")
async def add(request: Request, session: SessionDep, creator: AssistantBase):
await save(request, session, creator)
@router.put("")
@clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="editor.id")
async def update(request: Request, session: SessionDep, editor: AssistantDTO):
id = editor.id
db_model = session.get(AssistantModel, id)
if not db_model:
raise ValueError(f"AssistantModel with id {id} not found")
update_data = AssistantModel.model_validate(editor)
db_model.sqlmodel_update(update_data)
session.add(db_model)
session.commit()
dynamic_upgrade_cors(request=request, session=session)
@router.get("/{id}", response_model=AssistantModel)
async def get_one(session: SessionDep, id: int):
db_model = await get_assistant_info(session=session, assistant_id=id)
if not db_model:
raise ValueError(f"AssistantModel with id {id} not found")
db_model = AssistantModel.model_validate(db_model)
return db_model
@router.delete("/{id}")
@clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="id")
async def delete(request: Request, session: SessionDep, id: int):
db_model = session.get(AssistantModel, id)
if not db_model:
raise ValueError(f"AssistantModel with id {id} not found")
session.delete(db_model)
session.commit()
dynamic_upgrade_cors(request=request, session=session)