Add File
This commit is contained in:
265
backend/common/utils/utils.py
Normal file
265
backend/common/utils/utils.py
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
from pathlib import Path
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from common.core.config import settings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
import orjson
|
||||||
|
from jwt.exceptions import InvalidTokenError
|
||||||
|
|
||||||
|
from common.core import security
|
||||||
|
|
||||||
|
|
||||||
|
def generate_password_reset_token(email: str) -> str:
|
||||||
|
delta = timedelta(hours=settings.EMAIL_RESET_TOKEN_EXPIRE_HOURS)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
expires = now + delta
|
||||||
|
exp = expires.timestamp()
|
||||||
|
encoded_jwt = jwt.encode(
|
||||||
|
{"exp": exp, "nbf": now, "sub": email},
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithm=security.ALGORITHM,
|
||||||
|
)
|
||||||
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password_reset_token(token: str) -> str | None:
|
||||||
|
try:
|
||||||
|
decoded_token = jwt.decode(
|
||||||
|
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
|
||||||
|
)
|
||||||
|
return str(decoded_token["sub"])
|
||||||
|
except InvalidTokenError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def deepcopy_ignore_extra(src, dest):
|
||||||
|
import copy
|
||||||
|
for attr in vars(src):
|
||||||
|
if hasattr(dest, attr):
|
||||||
|
src_value = getattr(src, attr)
|
||||||
|
dest_value = copy.deepcopy(src_value) # deep copy
|
||||||
|
setattr(dest, attr, dest_value)
|
||||||
|
return dest
|
||||||
|
|
||||||
|
|
||||||
|
def extract_nested_json(text):
|
||||||
|
stack = []
|
||||||
|
start_index = -1
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, char in enumerate(text):
|
||||||
|
if char in '{[':
|
||||||
|
if not stack: # 记录起始位置
|
||||||
|
start_index = i
|
||||||
|
stack.append(char)
|
||||||
|
elif char in '}]':
|
||||||
|
if stack and ((char == '}' and stack[-1] == '{') or (char == ']' and stack[-1] == '[')):
|
||||||
|
stack.pop()
|
||||||
|
if not stack: # 栈空时截取完整JSON
|
||||||
|
json_str = text[start_index:i + 1]
|
||||||
|
try:
|
||||||
|
orjson.loads(json_str) # 验证有效性
|
||||||
|
results.append(json_str)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
stack = [] # 括号不匹配则重置
|
||||||
|
if len(results) > 0 and results[0]:
|
||||||
|
return results[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def string_to_numeric_hash(text: str, bits: Optional[int] = 64) -> int:
|
||||||
|
hash_bytes = hashlib.sha256(text.encode()).digest()
|
||||||
|
hash_num = int.from_bytes(hash_bytes, byteorder='big')
|
||||||
|
max_bigint = 2**63 - 1
|
||||||
|
return hash_num % max_bigint
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging():
|
||||||
|
# 确保日志目录存在
|
||||||
|
log_dir = Path(settings.LOG_DIR)
|
||||||
|
log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 日志格式
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
f'{settings.LOG_FORMAT}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 控制台日志
|
||||||
|
console_handler = logging.StreamHandler()
|
||||||
|
console_handler.setLevel(settings.LOG_LEVEL)
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
# 文件日志处理器
|
||||||
|
file_handlers = {
|
||||||
|
'debug': logging.DEBUG,
|
||||||
|
'info': logging.INFO,
|
||||||
|
'warn': logging.WARNING,
|
||||||
|
'error': logging.ERROR
|
||||||
|
}
|
||||||
|
|
||||||
|
# 主日志记录器
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
root_logger.setLevel(logging.DEBUG) # 设置最低级别
|
||||||
|
|
||||||
|
# 添加控制台处理器
|
||||||
|
root_logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
# 为每个级别创建文件处理器
|
||||||
|
for level_name, level in file_handlers.items():
|
||||||
|
file_path = log_dir / f"{level_name}.log"
|
||||||
|
handler = RotatingFileHandler(
|
||||||
|
file_path,
|
||||||
|
maxBytes=10 * 1024 * 1024, # 10MB
|
||||||
|
backupCount=5,
|
||||||
|
encoding='utf-8'
|
||||||
|
)
|
||||||
|
handler.setLevel(level)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
# 添加过滤器只处理特定级别日志
|
||||||
|
if level_name == 'debug':
|
||||||
|
handler.addFilter(lambda record: record.levelno == logging.DEBUG)
|
||||||
|
elif level_name == 'info':
|
||||||
|
handler.addFilter(lambda record: record.levelno == logging.INFO)
|
||||||
|
elif level_name == 'warn':
|
||||||
|
handler.addFilter(lambda record: record.levelno == logging.WARNING)
|
||||||
|
elif level_name == 'error':
|
||||||
|
handler.addFilter(lambda record: record.levelno >= logging.ERROR)
|
||||||
|
|
||||||
|
root_logger.addHandler(handler)
|
||||||
|
|
||||||
|
# SQL日志特殊处理
|
||||||
|
if settings.LOG_LEVEL == "DEBUG" and settings.SQL_DEBUG:
|
||||||
|
sql_logger = logging.getLogger('sqlalchemy.engine')
|
||||||
|
sql_logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
sql_handler = RotatingFileHandler(
|
||||||
|
log_dir / "sql.log",
|
||||||
|
maxBytes=10 * 1024 * 1024,
|
||||||
|
backupCount=2,
|
||||||
|
encoding='utf-8'
|
||||||
|
)
|
||||||
|
sql_handler.setFormatter(formatter)
|
||||||
|
sql_logger.addHandler(sql_handler)
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
|
||||||
|
|
||||||
|
class CallerLogger(logging.Logger):
|
||||||
|
def __init__(self, logger: logging.Logger):
|
||||||
|
self.logger = logger
|
||||||
|
super().__init__(logger.name, logger.level)
|
||||||
|
|
||||||
|
def _log(self, level, msg, args, exc_info=None, extra=None, stacklevel=3):
|
||||||
|
if self.logger.isEnabledFor(level):
|
||||||
|
self.logger._log(level, msg, args, exc_info=exc_info, extra=extra, stacklevel=stacklevel)
|
||||||
|
|
||||||
|
class SQLBotLogUtil:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_logger() -> logging.Logger:
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
try:
|
||||||
|
caller_frame = frame.f_back.f_back
|
||||||
|
module_name = caller_frame.f_globals.get('__name__', '__main__')
|
||||||
|
return CallerLogger(logging.getLogger(module_name))
|
||||||
|
finally:
|
||||||
|
del frame
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def debug(msg: str, *args, **kwargs):
|
||||||
|
logger = SQLBotLogUtil._get_logger()
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger._log(logging.DEBUG, msg, args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def info(msg: str, *args, **kwargs):
|
||||||
|
logger = SQLBotLogUtil._get_logger()
|
||||||
|
if logger.isEnabledFor(logging.INFO):
|
||||||
|
logger._log(logging.INFO, msg, args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def warning(msg: str, *args, **kwargs):
|
||||||
|
logger = SQLBotLogUtil._get_logger()
|
||||||
|
if logger.isEnabledFor(logging.WARNING):
|
||||||
|
logger._log(logging.WARNING, msg, args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def error(msg: str, *args, exc_info: Optional[bool] = None, **kwargs):
|
||||||
|
logger = SQLBotLogUtil._get_logger()
|
||||||
|
if logger.isEnabledFor(logging.ERROR):
|
||||||
|
logger._log(
|
||||||
|
logging.ERROR,
|
||||||
|
msg,
|
||||||
|
args,
|
||||||
|
exc_info=exc_info if exc_info is not None else True,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def exception(msg: str, *args, **kwargs):
|
||||||
|
logger = SQLBotLogUtil._get_logger()
|
||||||
|
if logger.isEnabledFor(logging.ERROR):
|
||||||
|
logger._log(logging.ERROR, msg, args, exc_info=True, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def critical(msg: str, *args, **kwargs):
|
||||||
|
logger = SQLBotLogUtil._get_logger()
|
||||||
|
if logger.isEnabledFor(logging.CRITICAL):
|
||||||
|
logger._log(logging.CRITICAL, msg, args, **kwargs)
|
||||||
|
|
||||||
|
def prepare_for_orjson(data):
|
||||||
|
if not data:
|
||||||
|
return data
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
return base64.b64encode(data).decode('utf-8')
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
return {k: prepare_for_orjson(v) for k, v in data.items()}
|
||||||
|
elif isinstance(data, (list, tuple)):
|
||||||
|
return [prepare_for_orjson(item) for item in data]
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_model_arg(origin_arg: str):
|
||||||
|
if not isinstance(origin_arg, str):
|
||||||
|
return origin_arg
|
||||||
|
if not origin_arg.strip()[0] in {'{', '['}:
|
||||||
|
return origin_arg
|
||||||
|
try:
|
||||||
|
return json.loads(origin_arg)
|
||||||
|
except:
|
||||||
|
return origin_arg
|
||||||
|
|
||||||
|
def get_origin_from_referer(request: Request):
|
||||||
|
referer = request.headers.get("referer")
|
||||||
|
if not referer:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = urlparse(referer)
|
||||||
|
if not parsed.scheme or not parsed.hostname:
|
||||||
|
return None
|
||||||
|
port = parsed.port
|
||||||
|
if port:
|
||||||
|
if (parsed.scheme == "http" and port != 80) or \
|
||||||
|
(parsed.scheme == "https" and port != 443):
|
||||||
|
return f"{parsed.scheme}://{parsed.hostname}:{port}"
|
||||||
|
|
||||||
|
return f"{parsed.scheme}://{parsed.hostname}"
|
||||||
|
except Exception as e:
|
||||||
|
SQLBotLogUtil.error(f"解析 Referer 出错: {e}")
|
||||||
|
return referer
|
||||||
|
|
||||||
Reference in New Issue
Block a user