191 lines
7.9 KiB
Python
191 lines
7.9 KiB
Python
|
|
from datetime import timedelta
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
from typing import List, Optional
|
||
|
|
from fastapi import APIRouter, Form, HTTPException, Query, Request, Response, UploadFile
|
||
|
|
from fastapi.responses import StreamingResponse
|
||
|
|
from sqlmodel import select
|
||
|
|
from apps.system.crud.assistant import get_assistant_info
|
||
|
|
from apps.system.crud.assistant_manage import dynamic_upgrade_cors, save
|
||
|
|
from apps.system.models.system_model import AssistantModel
|
||
|
|
from apps.system.schemas.auth import CacheName, CacheNamespace
|
||
|
|
from apps.system.schemas.system_schema import AssistantBase, AssistantDTO, AssistantUiSchema, AssistantValidator
|
||
|
|
from common.core.deps import SessionDep, Trans
|
||
|
|
from common.core.security import create_access_token
|
||
|
|
from common.core.sqlbot_cache import clear_cache
|
||
|
|
from common.utils.time import get_timestamp
|
||
|
|
|
||
|
|
from common.core.config import settings
|
||
|
|
from common.utils.utils import get_origin_from_referer
|
||
|
|
from sqlbot_xpack.file_utils import SQLBotFileUtils
|
||
|
|
|
||
|
|
# Router for embedded-assistant management; every route below is mounted
# under the /system/assistant prefix and grouped under one OpenAPI tag.
router = APIRouter(prefix="/system/assistant", tags=["system/assistant"])
|
||
|
|
|
||
|
|
@router.get("/info/{id}")
async def info(request: Request, response: Response, session: SessionDep, trans: Trans, id: int) -> AssistantModel:
    """Return assistant *id* to a caller whose origin equals its configured domain.

    The caller's origin is read from the ``Origin`` header, falling back to
    ``get_origin_from_referer``; trailing ``/`` characters are stripped before
    comparing. On success the domain is echoed in ``Access-Control-Allow-Origin``.

    :raises Exception: when *id* is falsy.
    :raises RuntimeError: when the assistant does not exist, or the origin is
        missing or differs from the assistant's ``domain``.
    """
    if not id:
        raise Exception('miss assistant id')
    db_model = await get_assistant_info(session=session, assistant_id=id)
    if not db_model:
        raise RuntimeError("assistant application not exist")
    db_model = AssistantModel.model_validate(db_model)

    origin = request.headers.get("origin") or get_origin_from_referer(request)
    if not origin:
        raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or ''))
    origin = origin.rstrip('/')
    if origin != db_model.domain:
        raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or ''))

    # Set the CORS header only once the origin is accepted. Assigning it earlier
    # meant a None / non-latin-1 domain failed inside the header encoder with an
    # unrelated error before the origin check could report invalid_origin.
    response.headers["Access-Control-Allow-Origin"] = db_model.domain
    return db_model
|
||
|
|
|
||
|
|
@router.get("/app/{appId}")
async def getApp(request: Request, response: Response, session: SessionDep, trans: Trans, appId: str) -> AssistantModel:
    """Return the assistant whose ``app_id`` is *appId*, if the caller's origin matches its domain.

    The caller's origin is read from the ``Origin`` header, falling back to
    ``get_origin_from_referer``; trailing ``/`` characters are stripped before
    comparing. On success the domain is echoed in ``Access-Control-Allow-Origin``.

    :raises Exception: when *appId* is empty.
    :raises RuntimeError: when no assistant has this app id, or the origin is
        missing or differs from the assistant's ``domain``.
    """
    if not appId:
        raise Exception('miss assistant appId')
    db_model = session.exec(select(AssistantModel).where(AssistantModel.app_id == appId)).first()
    if not db_model:
        raise RuntimeError("assistant application not exist")
    db_model = AssistantModel.model_validate(db_model)

    origin = request.headers.get("origin") or get_origin_from_referer(request)
    if not origin:
        raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or ''))
    origin = origin.rstrip('/')
    if origin != db_model.domain:
        raise RuntimeError(trans('i18n_embedded.invalid_origin', origin=origin or ''))

    # Set the CORS header only once the origin is accepted. Assigning it earlier
    # meant a None / non-latin-1 domain failed inside the header encoder with an
    # unrelated error before the origin check could report invalid_origin.
    response.headers["Access-Control-Allow-Origin"] = db_model.domain
    return db_model
|
||
|
|
|
||
|
|
@router.get("/validator", response_model=AssistantValidator)
async def validator(session: SessionDep, id: int, virtual: Optional[int] = Query(None)):
    """Issue an access token for assistant *id*.

    Returns an empty ``AssistantValidator()`` when the assistant is unknown.
    For type-0 assistants the token's ``oid`` comes from the ``oid`` key of the
    JSON configuration (default 1); otherwise ``oid`` is 1. *virtual* becomes
    the token's ``id`` claim.

    :raises Exception: when *id* is falsy.
    """
    if not id:
        raise Exception('miss assistant id')

    assistant = await get_assistant_info(session=session, assistant_id=id)
    if not assistant:
        return AssistantValidator()
    assistant = AssistantModel.model_validate(assistant)

    oid = 1
    if assistant.type == 0:
        raw_config = assistant.configuration
        parsed_config = json.loads(raw_config) if raw_config else {}
        oid = parsed_config.get('oid', 1)

    token_claims = {
        "id": virtual,
        "account": 'sqlbot-inner-assistant',
        "oid": oid,
        "assistant_id": id,
    }
    token = create_access_token(
        token_claims,
        expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return AssistantValidator(True, True, True, token)
|
||
|
|
|
||
|
|
@router.get('/picture/{file_id}')
async def picture(file_id: str):
    """Stream a stored assistant image identified by *file_id*.

    The media type is chosen from the extension of *file_id*: ``.svg`` ->
    ``image/svg+xml``, ``.png`` -> ``image/png``, anything else ``image/jpeg``.

    :raises HTTPException: 404 when no regular file exists for *file_id*.
    """
    file_path = SQLBotFileUtils.get_file_path(file_id=file_id)
    # isfile rather than exists: a directory would pass exists() and then
    # fail only once streaming had started, inside open().
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    # Uploads accept .jpg/.jpeg/.png/.svg; label PNGs correctly instead of
    # reporting every non-SVG image as JPEG.
    lowered = file_id.lower()
    if lowered.endswith(".svg"):
        media_type = "image/svg+xml"
    elif lowered.endswith(".png"):
        media_type = "image/png"
    else:
        media_type = "image/jpeg"

    def iterfile():
        # Read fixed-size chunks; iterating a binary file yields newline-split
        # pieces of arbitrary length.
        with open(file_path, mode="rb") as f:
            while chunk := f.read(64 * 1024):
                yield chunk

    return StreamingResponse(iterfile(), media_type=media_type)
|
||
|
|
|
||
|
|
@router.patch('/ui')
async def ui(session: SessionDep, data: str = Form(), files: List[UploadFile] = []):
    """Merge UI settings and optional image uploads into an assistant's configuration.

    :param data: JSON text parsed into ``AssistantUiSchema``; its ``id`` selects
        the assistant. Only fields that are set and non-None are merged, and the
        ``id`` key and dunder keys are never written into the configuration.
    :param files: optional uploads. Each filename is split by
        ``SQLBotFileUtils.split_filename_and_flag``; the flag must be ``logo`` or
        ``float_icon``. The uploaded file id replaces that flag's value.
    :raises ValueError: if the assistant does not exist or a file flag is
        unsupported.

    A stored ``logo`` / ``float_icon`` that is absent from the merged settings is
    deleted and set to None. After commit, ``clear_ui_cache`` is awaited for the
    assistant.
    """
    ui_schema = AssistantUiSchema(**json.loads(data))
    assistant_id = ui_schema.id
    db_model = session.get(AssistantModel, assistant_id)
    if not db_model:
        raise ValueError(f"AssistantModel with id {assistant_id} not found")
    configuration = db_model.configuration
    config_obj = json.loads(configuration) if configuration else {}

    ui_schema_dict = ui_schema.model_dump(exclude_none=True, exclude_unset=True)
    for file in files or []:
        file_name, flag_name = SQLBotFileUtils.split_filename_and_flag(file.filename)
        file.filename = file_name
        if flag_name not in ('logo', 'float_icon'):
            raise ValueError(f"Unsupported file flag: {flag_name}")
        SQLBotFileUtils.check_file(file=file, file_types=[".jpg", ".jpeg", ".png", ".svg"], limit_file_size=(10 * 1024 * 1024))
        previous_file_id = config_obj.get(flag_name)
        # Upload before deleting the previous image: if the upload fails, the
        # stored configuration must still reference a file that exists.
        file_id = await SQLBotFileUtils.upload(file)
        if previous_file_id:
            SQLBotFileUtils.detete_file(previous_file_id)
        ui_schema_dict[flag_name] = file_id

    # Remove stored images the submitted settings no longer reference.
    for flag_name in ('logo', 'float_icon'):
        file_val = config_obj.get(flag_name)
        if file_val and not ui_schema_dict.get(flag_name):
            config_obj[flag_name] = None
            SQLBotFileUtils.detete_file(file_val)

    for attr, value in ui_schema_dict.items():
        if attr != 'id' and not attr.startswith("__"):
            config_obj[attr] = value

    db_model.configuration = json.dumps(config_obj, ensure_ascii=False)
    session.add(db_model)
    session.commit()
    await clear_ui_cache(db_model.id)
|
||
|
|
|
||
|
|
@clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="id")
async def clear_ui_cache(id: int):
    """Hook whose only effect comes from the ``clear_cache`` decorator.

    The body does nothing itself; presumably the decorator evicts the
    ``ASSISTANT_INFO`` entry keyed by the ``id`` argument (see ``keyExpression``).
    """
|
||
|
|
|
||
|
|
@router.get("", response_model=list[AssistantModel])
async def query(session: SessionDep):
    """List all assistants except those of type 4, ordered by name then creation time."""
    statement = (
        select(AssistantModel)
        .where(AssistantModel.type != 4)
        .order_by(AssistantModel.name, AssistantModel.create_time)
    )
    return session.exec(statement).all()
|
||
|
|
|
||
|
|
@router.post("")
async def add(request: Request, session: SessionDep, creator: AssistantBase):
    """Create an assistant by delegating to the ``save`` helper; responds with no body data."""
    await save(request, session, creator)
    return None
|
||
|
|
|
||
|
|
|
||
|
|
@router.put("")
@clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="editor.id")
async def update(request: Request, session: SessionDep, editor: AssistantDTO):
    """Overwrite assistant ``editor.id`` with the values in *editor*, then refresh CORS.

    *editor* is validated into an ``AssistantModel`` and applied with
    ``sqlmodel_update`` before committing.

    :raises ValueError: if no assistant has ``editor.id``.
    """
    assistant_id = editor.id
    existing = session.get(AssistantModel, assistant_id)
    if not existing:
        raise ValueError(f"AssistantModel with id {assistant_id} not found")
    replacement = AssistantModel.model_validate(editor)
    existing.sqlmodel_update(replacement)
    session.add(existing)
    session.commit()
    # Domains may have changed, so the allowed CORS origins are recomputed.
    dynamic_upgrade_cors(request=request, session=session)
|
||
|
|
|
||
|
|
@router.get("/{id}", response_model=AssistantModel)
async def get_one(session: SessionDep, id: int):
    """Return the assistant with primary key *id*.

    :raises ValueError: if no such assistant exists.
    """
    found = await get_assistant_info(session=session, assistant_id=id)
    if found:
        return AssistantModel.model_validate(found)
    raise ValueError(f"AssistantModel with id {id} not found")
|
||
|
|
|
||
|
|
@router.delete("/{id}")
@clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="id")
async def delete(request: Request, session: SessionDep, id: int):
    """Delete assistant *id* and refresh the allowed CORS origins.

    :raises ValueError: if no such assistant exists.
    """
    target = session.get(AssistantModel, id)
    if not target:
        raise ValueError(f"AssistantModel with id {id} not found")
    session.delete(target)
    session.commit()
    # The removed assistant's domain should no longer be an allowed origin.
    dynamic_upgrade_cors(request=request, session=session)
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
|