This commit is contained in:
2025-09-08 16:36:12 +08:00
parent 89fe220113
commit 9ef0309062

View File

@@ -0,0 +1,178 @@
import base64
import json
from typing import Optional
from fastapi import Request
from fastapi.responses import JSONResponse
import jwt
from sqlmodel import Session
from starlette.middleware.base import BaseHTTPMiddleware
from apps.system.models.system_model import AssistantModel
from common.core.db import engine
from apps.system.crud.assistant import get_assistant_info, get_assistant_user
from apps.system.crud.user import get_user_by_account, get_user_info
from apps.system.schemas.system_schema import AssistantHeader, UserInfoDTO
from common.core import security
from common.core.config import settings
from common.core.schemas import TokenPayload
from common.utils.locale import I18n
from common.utils.utils import SQLBotLogUtil
from common.utils.whitelist import whiteUtils
from fastapi.security.utils import get_authorization_scheme_param
from common.core.deps import get_i18n
class TokenMiddleware(BaseHTTPMiddleware):
def __init__(self, app):
super().__init__(app)
async def dispatch(self, request, call_next):
if self.is_options(request) or whiteUtils.is_whitelisted(request.url.path):
return await call_next(request)
assistantTokenKey = settings.ASSISTANT_TOKEN_KEY
assistantToken = request.headers.get(assistantTokenKey)
trans = await get_i18n(request)
#if assistantToken and assistantToken.lower().startswith("assistant "):
if assistantToken:
validator: tuple[any] = await self.validateAssistant(assistantToken, trans)
if validator[0]:
request.state.current_user = validator[1]
request.state.assistant = validator[2]
return await call_next(request)
message = trans('i18n_permission.authenticate_invalid', msg = validator[1])
return JSONResponse(message, status_code=401, headers={"Access-Control-Allow-Origin": "*"})
#validate pass
tokenkey = settings.TOKEN_KEY
token = request.headers.get(tokenkey)
validate_pass, data = await self.validateToken(token, trans)
if validate_pass:
request.state.current_user = data
return await call_next(request)
message = trans('i18n_permission.authenticate_invalid', msg = data)
return JSONResponse(message, status_code=401, headers={"Access-Control-Allow-Origin": "*"})
def is_options(self, request: Request):
return request.method == "OPTIONS"
async def validateToken(self, token: Optional[str], trans: I18n):
if not token:
return False, f"Miss Token[{settings.TOKEN_KEY}]!"
schema, param = get_authorization_scheme_param(token)
if schema.lower() != "bearer":
return False, f"Token schema error!"
try:
payload = jwt.decode(
param, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
)
token_data = TokenPayload(**payload)
with Session(engine) as session:
session_user = await get_user_info(session = session, user_id = token_data.id)
if not session_user:
message = trans('i18n_not_exist', msg = trans('i18n_user.account'))
raise Exception(message)
session_user = UserInfoDTO.model_validate(session_user)
if session_user.status != 1:
message = trans('i18n_login.user_disable', msg = trans('i18n_concat_admin'))
raise Exception(message)
if not session_user.oid or session_user.oid == 0:
message = trans('i18n_login.no_associated_ws', msg = trans('i18n_concat_admin'))
raise Exception(message)
return True, session_user
except Exception as e:
msg = str(e)
SQLBotLogUtil.exception(f"Token validation error: {msg}")
if 'expired' in msg:
return False, jwt.ExpiredSignatureError(trans('i18n_permission.token_expired'))
return False, e
async def validateAssistant(self, assistantToken: Optional[str], trans: I18n) -> tuple[any]:
if not assistantToken:
return False, f"Miss Token[{settings.TOKEN_KEY}]!"
schema, param = get_authorization_scheme_param(assistantToken)
try:
if schema.lower() == 'embedded':
return await self.validateEmbedded(param, trans)
if schema.lower() != "assistant":
return False, f"Token schema error!"
payload = jwt.decode(
param, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
)
token_data = TokenPayload(**payload)
if not payload['assistant_id']:
return False, f"Miss assistant payload error!"
with Session(engine) as session:
""" session_user = await get_user_info(session = session, user_id = token_data.id)
session_user = UserInfoDTO.model_validate(session_user) """
session_user = get_assistant_user(id = token_data.id)
assistant_info = await get_assistant_info(session=session, assistant_id=payload['assistant_id'])
assistant_info = AssistantModel.model_validate(assistant_info)
assistant_info = AssistantHeader.model_validate(assistant_info.model_dump(exclude_unset=True))
if assistant_info and assistant_info.type == 0:
if payload['oid']:
session_user.oid = int(payload['oid'])
else:
assistant_oid = 1
configuration = assistant_info.configuration
config_obj = json.loads(configuration) if configuration else {}
assistant_oid = config_obj.get('oid', 1)
session_user.oid = int(assistant_oid)
return True, session_user, assistant_info
except Exception as e:
SQLBotLogUtil.exception(f"Assistant validation error: {str(e)}")
# Return False and the exception message
return False, e
async def validateEmbedded(self, param: str, trans: I18n) -> tuple[any]:
try:
""" payload = jwt.decode(
param, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
) """
payload: dict = jwt.decode(
param,
options={"verify_signature": False, "verify_exp": False},
algorithms=[security.ALGORITHM]
)
app_key = payload.get('appId', '')
embeddedId = payload.get('embeddedId', None)
if not embeddedId:
embeddedId = xor_decrypt(app_key)
if not payload['account']:
return False, f"Miss account payload error!"
account = payload['account']
with Session(engine) as session:
""" session_user = await get_user_info(session = session, user_id = token_data.id)
session_user = UserInfoDTO.model_validate(session_user) """
session_user = get_user_by_account(session = session, account=account)
if not session_user:
message = trans('i18n_not_exist', msg = trans('i18n_user.account'))
raise Exception(message)
session_user = await get_user_info(session = session, user_id = session_user.id)
session_user = UserInfoDTO.model_validate(session_user)
if session_user.status != 1:
message = trans('i18n_login.user_disable', msg = trans('i18n_concat_admin'))
raise Exception(message)
if not session_user.oid or session_user.oid == 0:
message = trans('i18n_login.no_associated_ws', msg = trans('i18n_concat_admin'))
raise Exception(message)
assistant_info = await get_assistant_info(session=session, assistant_id=embeddedId)
assistant_info = AssistantModel.model_validate(assistant_info)
assistant_info = AssistantHeader.model_validate(assistant_info.model_dump(exclude_unset=True))
return True, session_user, assistant_info
except Exception as e:
SQLBotLogUtil.exception(f"Embedded validation error: {str(e)}")
# Return False and the exception message
return False, e
def xor_decrypt(encrypted_str: str, key: int = 0xABCD1234) -> int:
encrypted_bytes = base64.urlsafe_b64decode(encrypted_str)
hex_str = encrypted_bytes.hex()
encrypted_num = int(hex_str, 16)
return encrypted_num ^ key