Add File
This commit is contained in:
113
backend/apps/mcp/mcp.py
Normal file
113
backend/apps/mcp/mcp.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Author: Junjun
|
||||
# Date: 2025/7/1
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, status, APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
# from fastapi.security import OAuth2PasswordBearer
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import select
|
||||
|
||||
from apps.chat.api.chat import create_chat
|
||||
from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion
|
||||
from apps.chat.task.llm import LLMService
|
||||
from apps.system.crud.user import authenticate
|
||||
from apps.system.crud.user import get_db_user
|
||||
from apps.system.models.system_model import UserWsModel
|
||||
from apps.system.models.user import UserModel
|
||||
from apps.system.schemas.system_schema import BaseUserDTO
|
||||
from apps.system.schemas.system_schema import UserInfoDTO
|
||||
from common.core import security
|
||||
from common.core.config import settings
|
||||
from common.core.deps import SessionDep
|
||||
from common.core.schemas import TokenPayload, XOAuth2PasswordBearer, Token
|
||||
from common.core.security import create_access_token
|
||||
|
||||
reusable_oauth2 = XOAuth2PasswordBearer(
|
||||
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["mcp"], prefix="/mcp")
|
||||
|
||||
|
||||
# @router.post("/access_token", operation_id="access_token")
|
||||
# def local_login(
|
||||
# session: SessionDep,
|
||||
# form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
|
||||
# ) -> Token:
|
||||
# user = authenticate(session=session, account=form_data.username, password=form_data.password)
|
||||
# if not user:
|
||||
# raise HTTPException(status_code=400, detail="Incorrect account or password")
|
||||
# access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
# user_dict = user.to_dict()
|
||||
# return Token(access_token=create_access_token(
|
||||
# user_dict, expires_delta=access_token_expires
|
||||
# ))
|
||||
|
||||
|
||||
# @router.get("/ds_list", operation_id="get_datasource_list")
|
||||
# async def datasource_list(session: SessionDep):
|
||||
# return get_datasource_list(session=session)
|
||||
#
|
||||
#
|
||||
# @router.get("/model_list", operation_id="get_model_list")
|
||||
# async def get_model_list(session: SessionDep):
|
||||
# return session.query(AiModelDetail).all()
|
||||
|
||||
|
||||
@router.post("/mcp_start", operation_id="mcp_start")
|
||||
async def mcp_start(session: SessionDep, chat: ChatStart):
|
||||
user: BaseUserDTO = authenticate(session=session, account=chat.username, password=chat.password)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="Incorrect account or password")
|
||||
|
||||
if not user.oid or user.oid == 0:
|
||||
raise HTTPException(status_code=400, detail="No associated workspace, Please contact the administrator")
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
user_dict = user.to_dict()
|
||||
t = Token(access_token=create_access_token(
|
||||
user_dict, expires_delta=access_token_expires
|
||||
))
|
||||
c = create_chat(session, user, CreateChat(origin=1), False)
|
||||
return {"access_token": t.access_token, "chat_id": c.id}
|
||||
|
||||
|
||||
@router.post("/mcp_question", operation_id="mcp_question")
|
||||
async def mcp_question(session: SessionDep, chat: McpQuestion):
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
chat.token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
|
||||
)
|
||||
token_data = TokenPayload(**payload)
|
||||
except (InvalidTokenError, ValidationError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Could not validate credentials",
|
||||
)
|
||||
# session_user = await get_user_info(session=session, user_id=token_data.id)
|
||||
|
||||
db_user: UserModel = get_db_user(session=session, user_id=token_data.id)
|
||||
session_user = UserInfoDTO.model_validate(db_user.model_dump())
|
||||
session_user.isAdmin = session_user.id == 1 and session_user.account == 'admin'
|
||||
if session_user.isAdmin:
|
||||
session_user = session_user
|
||||
ws_model: UserWsModel = session.exec(
|
||||
select(UserWsModel).where(UserWsModel.uid == session_user.id, UserWsModel.oid == session_user.oid)).first()
|
||||
session_user.weight = ws_model.weight if ws_model else -1
|
||||
|
||||
session_user = UserInfoDTO.model_validate(session_user)
|
||||
if not session_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if session_user.status != 1:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
|
||||
mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question)
|
||||
# ask
|
||||
llm_service = await LLMService.create(session_user, mcp_chat)
|
||||
llm_service.init_record()
|
||||
|
||||
return StreamingResponse(llm_service.run_task(False), media_type="text/event-stream")
|
||||
Reference in New Issue
Block a user