From 9ef03090622570728620149604493fd20d13676d Mon Sep 17 00:00:00 2001 From: inter Date: Mon, 8 Sep 2025 16:36:12 +0800 Subject: [PATCH] Add File --- backend/apps/system/middleware/auth.py | 178 +++++++++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 backend/apps/system/middleware/auth.py diff --git a/backend/apps/system/middleware/auth.py b/backend/apps/system/middleware/auth.py new file mode 100644 index 0000000..3ea720c --- /dev/null +++ b/backend/apps/system/middleware/auth.py @@ -0,0 +1,178 @@ + +import base64 +import json +from typing import Optional +from fastapi import Request +from fastapi.responses import JSONResponse +import jwt +from sqlmodel import Session +from starlette.middleware.base import BaseHTTPMiddleware +from apps.system.models.system_model import AssistantModel +from common.core.db import engine +from apps.system.crud.assistant import get_assistant_info, get_assistant_user +from apps.system.crud.user import get_user_by_account, get_user_info +from apps.system.schemas.system_schema import AssistantHeader, UserInfoDTO +from common.core import security +from common.core.config import settings +from common.core.schemas import TokenPayload +from common.utils.locale import I18n +from common.utils.utils import SQLBotLogUtil +from common.utils.whitelist import whiteUtils +from fastapi.security.utils import get_authorization_scheme_param +from common.core.deps import get_i18n +class TokenMiddleware(BaseHTTPMiddleware): + + + + def __init__(self, app): + super().__init__(app) + + async def dispatch(self, request, call_next): + + if self.is_options(request) or whiteUtils.is_whitelisted(request.url.path): + return await call_next(request) + assistantTokenKey = settings.ASSISTANT_TOKEN_KEY + assistantToken = request.headers.get(assistantTokenKey) + trans = await get_i18n(request) + #if assistantToken and assistantToken.lower().startswith("assistant "): + if assistantToken: + validator: tuple[any] = await self.validateAssistant(assistantToken, trans) + if validator[0]: + request.state.current_user = validator[1] + request.state.assistant = validator[2] + return await call_next(request) + message = trans('i18n_permission.authenticate_invalid', msg = validator[1]) + return JSONResponse(message, status_code=401, headers={"Access-Control-Allow-Origin": "*"}) + #validate pass + tokenkey = settings.TOKEN_KEY + token = request.headers.get(tokenkey) + validate_pass, data = await self.validateToken(token, trans) + if validate_pass: + request.state.current_user = data + return await call_next(request) + + message = trans('i18n_permission.authenticate_invalid', msg = data) + return JSONResponse(message, status_code=401, headers={"Access-Control-Allow-Origin": "*"}) + + def is_options(self, request: Request): + return request.method == "OPTIONS" + + async def validateToken(self, token: Optional[str], trans: I18n): + if not token: + return False, f"Miss Token[{settings.TOKEN_KEY}]!" + schema, param = get_authorization_scheme_param(token) + if schema.lower() != "bearer": + return False, f"Token schema error!" + try: + payload = jwt.decode( + param, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + ) + token_data = TokenPayload(**payload) + with Session(engine) as session: + session_user = await get_user_info(session = session, user_id = token_data.id) + if not session_user: + message = trans('i18n_not_exist', msg = trans('i18n_user.account')) + raise Exception(message) + session_user = UserInfoDTO.model_validate(session_user) + if session_user.status != 1: + message = trans('i18n_login.user_disable', msg = trans('i18n_concat_admin')) + raise Exception(message) + if not session_user.oid or session_user.oid == 0: + message = trans('i18n_login.no_associated_ws', msg = trans('i18n_concat_admin')) + raise Exception(message) + return True, session_user + except Exception as e: + msg = str(e) + SQLBotLogUtil.exception(f"Token validation error: {msg}") + if 'expired' in msg: + return False, jwt.ExpiredSignatureError(trans('i18n_permission.token_expired')) + return False, e + + + async def validateAssistant(self, assistantToken: Optional[str], trans: I18n) -> tuple[any]: + if not assistantToken: + return False, f"Miss Token[{settings.TOKEN_KEY}]!" + schema, param = get_authorization_scheme_param(assistantToken) + + + try: + if schema.lower() == 'embedded': + return await self.validateEmbedded(param, trans) + if schema.lower() != "assistant": + return False, f"Token schema error!" + payload = jwt.decode( + param, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + ) + token_data = TokenPayload(**payload) + if not payload['assistant_id']: + return False, f"Miss assistant payload error!" + with Session(engine) as session: + """ session_user = await get_user_info(session = session, user_id = token_data.id) + session_user = UserInfoDTO.model_validate(session_user) """ + session_user = get_assistant_user(id = token_data.id) + assistant_info = await get_assistant_info(session=session, assistant_id=payload['assistant_id']) + assistant_info = AssistantModel.model_validate(assistant_info) + assistant_info = AssistantHeader.model_validate(assistant_info.model_dump(exclude_unset=True)) + if assistant_info and assistant_info.type == 0: + if payload['oid']: + session_user.oid = int(payload['oid']) + else: + assistant_oid = 1 + configuration = assistant_info.configuration + config_obj = json.loads(configuration) if configuration else {} + assistant_oid = config_obj.get('oid', 1) + session_user.oid = int(assistant_oid) + + return True, session_user, assistant_info + except Exception as e: + SQLBotLogUtil.exception(f"Assistant validation error: {str(e)}") + # Return False and the exception message + return False, e + + async def validateEmbedded(self, param: str, trans: I18n) -> tuple[any]: + try: + """ payload = jwt.decode( + param, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + ) """ + payload: dict = jwt.decode( + param, + options={"verify_signature": False, "verify_exp": False}, + algorithms=[security.ALGORITHM] + ) + app_key = payload.get('appId', '') + embeddedId = payload.get('embeddedId', None) + if not embeddedId: + embeddedId = xor_decrypt(app_key) + if not payload['account']: + return False, f"Miss account payload error!" + account = payload['account'] + with Session(engine) as session: + """ session_user = await get_user_info(session = session, user_id = token_data.id) + session_user = UserInfoDTO.model_validate(session_user) """ + session_user = get_user_by_account(session = session, account=account) + if not session_user: + message = trans('i18n_not_exist', msg = trans('i18n_user.account')) + raise Exception(message) + session_user = await get_user_info(session = session, user_id = session_user.id) + + session_user = UserInfoDTO.model_validate(session_user) + if session_user.status != 1: + message = trans('i18n_login.user_disable', msg = trans('i18n_concat_admin')) + raise Exception(message) + if not session_user.oid or session_user.oid == 0: + message = trans('i18n_login.no_associated_ws', msg = trans('i18n_concat_admin')) + raise Exception(message) + assistant_info = await get_assistant_info(session=session, assistant_id=embeddedId) + assistant_info = AssistantModel.model_validate(assistant_info) + assistant_info = AssistantHeader.model_validate(assistant_info.model_dump(exclude_unset=True)) + return True, session_user, assistant_info + except Exception as e: + SQLBotLogUtil.exception(f"Embedded validation error: {str(e)}") + # Return False and the exception message + return False, e + +def xor_decrypt(encrypted_str: str, key: int = 0xABCD1234) -> int: + encrypted_bytes = base64.urlsafe_b64decode(encrypted_str) + hex_str = encrypted_bytes.hex() + encrypted_num = int(hex_str, 16) + return encrypted_num ^ key \ No newline at end of file