diff --git a/backend/apps/mcp/mcp.py b/backend/apps/mcp/mcp.py new file mode 100644 index 0000000..1587879 --- /dev/null +++ b/backend/apps/mcp/mcp.py @@ -0,0 +1,113 @@ +# Author: Junjun +# Date: 2025/7/1 + +from datetime import timedelta + +import jwt +from fastapi import HTTPException, status, APIRouter +from fastapi.responses import StreamingResponse +# from fastapi.security import OAuth2PasswordBearer +from jwt.exceptions import InvalidTokenError +from pydantic import ValidationError +from sqlmodel import select + +from apps.chat.api.chat import create_chat +from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion +from apps.chat.task.llm import LLMService +from apps.system.crud.user import authenticate +from apps.system.crud.user import get_db_user +from apps.system.models.system_model import UserWsModel +from apps.system.models.user import UserModel +from apps.system.schemas.system_schema import BaseUserDTO +from apps.system.schemas.system_schema import UserInfoDTO +from common.core import security +from common.core.config import settings +from common.core.deps import SessionDep +from common.core.schemas import TokenPayload, XOAuth2PasswordBearer, Token +from common.core.security import create_access_token + +reusable_oauth2 = XOAuth2PasswordBearer( + tokenUrl=f"{settings.API_V1_STR}/login/access-token" +) + +router = APIRouter(tags=["mcp"], prefix="/mcp") + + +# @router.post("/access_token", operation_id="access_token") +# def local_login( +# session: SessionDep, +# form_data: Annotated[OAuth2PasswordRequestForm, Depends()] +# ) -> Token: +# user = authenticate(session=session, account=form_data.username, password=form_data.password) +# if not user: +# raise HTTPException(status_code=400, detail="Incorrect account or password") +# access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) +# user_dict = user.to_dict() +# return Token(access_token=create_access_token( +# user_dict, expires_delta=access_token_expires +# )) + + +# @router.get("/ds_list", operation_id="get_datasource_list") +# async def datasource_list(session: SessionDep): +# return get_datasource_list(session=session) +# +# +# @router.get("/model_list", operation_id="get_model_list") +# async def get_model_list(session: SessionDep): +# return session.query(AiModelDetail).all() + + +@router.post("/mcp_start", operation_id="mcp_start") +async def mcp_start(session: SessionDep, chat: ChatStart): + user: BaseUserDTO = authenticate(session=session, account=chat.username, password=chat.password) + if not user: + raise HTTPException(status_code=400, detail="Incorrect account or password") + + if not user.oid or user.oid == 0: + raise HTTPException(status_code=400, detail="No associated workspace, Please contact the administrator") + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + user_dict = user.to_dict() + t = Token(access_token=create_access_token( + user_dict, expires_delta=access_token_expires + )) + c = create_chat(session, user, CreateChat(origin=1), False) + return {"access_token": t.access_token, "chat_id": c.id} + + +@router.post("/mcp_question", operation_id="mcp_question") +async def mcp_question(session: SessionDep, chat: McpQuestion): + try: + payload = jwt.decode( + chat.token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + ) + token_data = TokenPayload(**payload) + except (InvalidTokenError, ValidationError): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Could not validate credentials", + ) + # session_user = await get_user_info(session=session, user_id=token_data.id) + + db_user: UserModel = get_db_user(session=session, user_id=token_data.id) + session_user = UserInfoDTO.model_validate(db_user.model_dump()) + session_user.isAdmin = session_user.id == 1 and session_user.account == 'admin' + if session_user.isAdmin: + session_user = session_user + ws_model: UserWsModel = session.exec( + select(UserWsModel).where(UserWsModel.uid == session_user.id, UserWsModel.oid == session_user.oid)).first() + session_user.weight = ws_model.weight if ws_model else -1 + + session_user = UserInfoDTO.model_validate(session_user) + if not session_user: + raise HTTPException(status_code=404, detail="User not found") + + if session_user.status != 1: + raise HTTPException(status_code=400, detail="Inactive user") + + mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question) + # ask + llm_service = await LLMService.create(session_user, mcp_chat) + llm_service.init_record() + + return StreamingResponse(llm_service.run_task(False), media_type="text/event-stream")