Add File
This commit is contained in:
178
backend/apps/system/middleware/auth.py
Normal file
178
backend/apps/system/middleware/auth.py
Normal file
@@ -0,0 +1,178 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import Optional
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
import jwt
|
||||
from sqlmodel import Session
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from apps.system.models.system_model import AssistantModel
|
||||
from common.core.db import engine
|
||||
from apps.system.crud.assistant import get_assistant_info, get_assistant_user
|
||||
from apps.system.crud.user import get_user_by_account, get_user_info
|
||||
from apps.system.schemas.system_schema import AssistantHeader, UserInfoDTO
|
||||
from common.core import security
|
||||
from common.core.config import settings
|
||||
from common.core.schemas import TokenPayload
|
||||
from common.utils.locale import I18n
|
||||
from common.utils.utils import SQLBotLogUtil
|
||||
from common.utils.whitelist import whiteUtils
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from common.core.deps import get_i18n
|
||||
class TokenMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request, call_next):
|
||||
|
||||
if self.is_options(request) or whiteUtils.is_whitelisted(request.url.path):
|
||||
return await call_next(request)
|
||||
assistantTokenKey = settings.ASSISTANT_TOKEN_KEY
|
||||
assistantToken = request.headers.get(assistantTokenKey)
|
||||
trans = await get_i18n(request)
|
||||
#if assistantToken and assistantToken.lower().startswith("assistant "):
|
||||
if assistantToken:
|
||||
validator: tuple[any] = await self.validateAssistant(assistantToken, trans)
|
||||
if validator[0]:
|
||||
request.state.current_user = validator[1]
|
||||
request.state.assistant = validator[2]
|
||||
return await call_next(request)
|
||||
message = trans('i18n_permission.authenticate_invalid', msg = validator[1])
|
||||
return JSONResponse(message, status_code=401, headers={"Access-Control-Allow-Origin": "*"})
|
||||
#validate pass
|
||||
tokenkey = settings.TOKEN_KEY
|
||||
token = request.headers.get(tokenkey)
|
||||
validate_pass, data = await self.validateToken(token, trans)
|
||||
if validate_pass:
|
||||
request.state.current_user = data
|
||||
return await call_next(request)
|
||||
|
||||
message = trans('i18n_permission.authenticate_invalid', msg = data)
|
||||
return JSONResponse(message, status_code=401, headers={"Access-Control-Allow-Origin": "*"})
|
||||
|
||||
def is_options(self, request: Request):
|
||||
return request.method == "OPTIONS"
|
||||
|
||||
async def validateToken(self, token: Optional[str], trans: I18n):
|
||||
if not token:
|
||||
return False, f"Miss Token[{settings.TOKEN_KEY}]!"
|
||||
schema, param = get_authorization_scheme_param(token)
|
||||
if schema.lower() != "bearer":
|
||||
return False, f"Token schema error!"
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
param, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
|
||||
)
|
||||
token_data = TokenPayload(**payload)
|
||||
with Session(engine) as session:
|
||||
session_user = await get_user_info(session = session, user_id = token_data.id)
|
||||
if not session_user:
|
||||
message = trans('i18n_not_exist', msg = trans('i18n_user.account'))
|
||||
raise Exception(message)
|
||||
session_user = UserInfoDTO.model_validate(session_user)
|
||||
if session_user.status != 1:
|
||||
message = trans('i18n_login.user_disable', msg = trans('i18n_concat_admin'))
|
||||
raise Exception(message)
|
||||
if not session_user.oid or session_user.oid == 0:
|
||||
message = trans('i18n_login.no_associated_ws', msg = trans('i18n_concat_admin'))
|
||||
raise Exception(message)
|
||||
return True, session_user
|
||||
except Exception as e:
|
||||
msg = str(e)
|
||||
SQLBotLogUtil.exception(f"Token validation error: {msg}")
|
||||
if 'expired' in msg:
|
||||
return False, jwt.ExpiredSignatureError(trans('i18n_permission.token_expired'))
|
||||
return False, e
|
||||
|
||||
|
||||
async def validateAssistant(self, assistantToken: Optional[str], trans: I18n) -> tuple[any]:
|
||||
if not assistantToken:
|
||||
return False, f"Miss Token[{settings.TOKEN_KEY}]!"
|
||||
schema, param = get_authorization_scheme_param(assistantToken)
|
||||
|
||||
|
||||
try:
|
||||
if schema.lower() == 'embedded':
|
||||
return await self.validateEmbedded(param, trans)
|
||||
if schema.lower() != "assistant":
|
||||
return False, f"Token schema error!"
|
||||
payload = jwt.decode(
|
||||
param, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
|
||||
)
|
||||
token_data = TokenPayload(**payload)
|
||||
if not payload['assistant_id']:
|
||||
return False, f"Miss assistant payload error!"
|
||||
with Session(engine) as session:
|
||||
""" session_user = await get_user_info(session = session, user_id = token_data.id)
|
||||
session_user = UserInfoDTO.model_validate(session_user) """
|
||||
session_user = get_assistant_user(id = token_data.id)
|
||||
assistant_info = await get_assistant_info(session=session, assistant_id=payload['assistant_id'])
|
||||
assistant_info = AssistantModel.model_validate(assistant_info)
|
||||
assistant_info = AssistantHeader.model_validate(assistant_info.model_dump(exclude_unset=True))
|
||||
if assistant_info and assistant_info.type == 0:
|
||||
if payload['oid']:
|
||||
session_user.oid = int(payload['oid'])
|
||||
else:
|
||||
assistant_oid = 1
|
||||
configuration = assistant_info.configuration
|
||||
config_obj = json.loads(configuration) if configuration else {}
|
||||
assistant_oid = config_obj.get('oid', 1)
|
||||
session_user.oid = int(assistant_oid)
|
||||
|
||||
return True, session_user, assistant_info
|
||||
except Exception as e:
|
||||
SQLBotLogUtil.exception(f"Assistant validation error: {str(e)}")
|
||||
# Return False and the exception message
|
||||
return False, e
|
||||
|
||||
async def validateEmbedded(self, param: str, trans: I18n) -> tuple[any]:
|
||||
try:
|
||||
""" payload = jwt.decode(
|
||||
param, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
|
||||
) """
|
||||
payload: dict = jwt.decode(
|
||||
param,
|
||||
options={"verify_signature": False, "verify_exp": False},
|
||||
algorithms=[security.ALGORITHM]
|
||||
)
|
||||
app_key = payload.get('appId', '')
|
||||
embeddedId = payload.get('embeddedId', None)
|
||||
if not embeddedId:
|
||||
embeddedId = xor_decrypt(app_key)
|
||||
if not payload['account']:
|
||||
return False, f"Miss account payload error!"
|
||||
account = payload['account']
|
||||
with Session(engine) as session:
|
||||
""" session_user = await get_user_info(session = session, user_id = token_data.id)
|
||||
session_user = UserInfoDTO.model_validate(session_user) """
|
||||
session_user = get_user_by_account(session = session, account=account)
|
||||
if not session_user:
|
||||
message = trans('i18n_not_exist', msg = trans('i18n_user.account'))
|
||||
raise Exception(message)
|
||||
session_user = await get_user_info(session = session, user_id = session_user.id)
|
||||
|
||||
session_user = UserInfoDTO.model_validate(session_user)
|
||||
if session_user.status != 1:
|
||||
message = trans('i18n_login.user_disable', msg = trans('i18n_concat_admin'))
|
||||
raise Exception(message)
|
||||
if not session_user.oid or session_user.oid == 0:
|
||||
message = trans('i18n_login.no_associated_ws', msg = trans('i18n_concat_admin'))
|
||||
raise Exception(message)
|
||||
assistant_info = await get_assistant_info(session=session, assistant_id=embeddedId)
|
||||
assistant_info = AssistantModel.model_validate(assistant_info)
|
||||
assistant_info = AssistantHeader.model_validate(assistant_info.model_dump(exclude_unset=True))
|
||||
return True, session_user, assistant_info
|
||||
except Exception as e:
|
||||
SQLBotLogUtil.exception(f"Embedded validation error: {str(e)}")
|
||||
# Return False and the exception message
|
||||
return False, e
|
||||
|
||||
def xor_decrypt(encrypted_str: str, key: int = 0xABCD1234) -> int:
|
||||
encrypted_bytes = base64.urlsafe_b64decode(encrypted_str)
|
||||
hex_str = encrypted_bytes.hex()
|
||||
encrypted_num = int(hex_str, 16)
|
||||
return encrypted_num ^ key
|
||||
Reference in New Issue
Block a user