Files
SQLBot/backend/apps/system/crud/assistant.py
2025-09-08 16:36:15 +08:00

237 lines
9.7 KiB
Python

import json
import urllib
from typing import Optional
import requests
from fastapi import FastAPI
from sqlalchemy import Engine, create_engine
from sqlmodel import Session, select
from starlette.middleware.cors import CORSMiddleware
from apps.datasource.models.datasource import CoreDatasource, DatasourceConf
from apps.system.models.system_model import AssistantModel
from apps.system.schemas.auth import CacheName, CacheNamespace
from apps.system.schemas.system_schema import AssistantHeader, AssistantOutDsSchema, UserInfoDTO
from common.core.config import settings
from common.core.db import engine
from common.core.sqlbot_cache import cache
from common.utils.aes_crypto import simple_aes_decrypt
from common.utils.utils import string_to_numeric_hash
@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id")
async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None:
    """Load an assistant row by primary key.

    Wrapped by the project ``cache`` decorator in the ``EMBEDDED_INFO``
    namespace / ``ASSISTANT_INFO`` cache, keyed by the ``assistant_id``
    expression.

    Args:
        session: Open database session.
        assistant_id: Primary key of the assistant.

    Returns:
        The matching ``AssistantModel``, or ``None`` when no row has that id.
    """
    return session.get(AssistantModel, assistant_id)
def get_assistant_user(*, id: int):
    """Return the fixed pseudo-user that represents the built-in assistant.

    Only ``id`` varies between calls. The account, display name and email are
    constant placeholders, and the organisation id (``oid``) is always 1.
    """
    fields = {
        "account": "sqlbot-inner-assistant",
        "oid": 1,
        "name": "sqlbot-inner-assistant",
        "email": "sqlbot-inner-assistant@sqlbot.com",
    }
    return UserInfoDTO(id=id, **fields)
def get_assistant_ds(session: Session, llm_service) -> list[dict]:
    """Return the datasources visible to the assistant on ``llm_service``.

    Assistants of type 0 or 2 read datasources belonging to the configured
    organisation (``oid``) from the local database. Every other assistant type
    fetches its datasources from an external endpoint via ``AssistantOutDs``.
    That instance is stored on ``llm_service.out_ds_instance`` so later steps
    can reuse it.

    Args:
        session: Database session used for the local lookup.
        llm_service: Object exposing ``current_assistant`` (an
            ``AssistantHeader``); receives ``out_ds_instance`` for external
            assistants.

    Returns:
        A list of ``{"id", "name", "description"}`` dicts.
    """
    assistant: AssistantHeader = llm_service.current_assistant
    assistant_type = assistant.type

    if assistant_type in (0, 2):
        configuration = assistant.configuration
        if not configuration:
            # Without a configuration there is no organisation to query. The
            # external branch below cannot serve this assistant either: it
            # json-decodes this same empty configuration and always fails.
            return []

        config: dict = json.loads(configuration)
        oid = int(config['oid'])
        stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(
            CoreDatasource.oid == oid)

        if not assistant.online:
            # When not online, only the ids listed in public_list are exposed.
            # A missing or empty list exposes nothing.
            public_list = config.get('public_list')
            if not public_list:
                return []
            stmt = stmt.where(CoreDatasource.id.in_(public_list))

        rows = session.exec(stmt)
        return [
            {"id": row.id, "name": row.name, "description": row.description}
            for row in rows
        ]

    out_ds_instance: AssistantOutDs = AssistantOutDsFactory.get_instance(assistant)
    llm_service.out_ds_instance = out_ds_instance
    return out_ds_instance.get_simple_ds_list()
def init_dynamic_cors(app: FastAPI):
    """Extend the app's CORS allow-list with every assistant's configured domains.

    Reads all ``AssistantModel`` rows ordered by ``create_time`` and splits
    each comma-separated ``domain`` field. The distinct, non-blank entries are
    appended after ``settings.all_cors_origins`` in the ``allow_origins``
    kwarg of the first registered ``CORSMiddleware``. When no
    ``CORSMiddleware`` is registered, nothing is changed.

    Args:
        app: The FastAPI application whose middleware list is updated.

    Returns:
        ``None`` on success, or the tuple ``(False, exception)`` if any step
        raised. The error is returned rather than raised. Note that the
        failure tuple is truthy.
    """
    try:
        # dict keys keep insertion order, giving an order-preserving
        # de-duplication (first occurrence, in creation order, wins).
        assistant_domains: dict[str, None] = {}
        with Session(engine) as session:
            assistants = session.exec(
                select(AssistantModel).order_by(AssistantModel.create_time)
            ).all()
            for item in assistants:
                if not item.domain:
                    continue
                for raw_domain in item.domain.split(','):
                    domain = raw_domain.strip()
                    if domain:
                        assistant_domains.setdefault(domain, None)

        cors_middleware = next(
            (middleware for middleware in app.user_middleware if middleware.cls == CORSMiddleware),
            None,
        )
        if cors_middleware is not None:
            # Deterministic, order-preserving merge. list(set(...)) would yield a
            # different origin order per process because str hashing is randomized.
            merged = dict.fromkeys([*settings.all_cors_origins, *assistant_domains])
            cors_middleware.kwargs['allow_origins'] = list(merged)
    except Exception as e:
        return False, e
class AssistantOutDs:
    """Datasources supplied to an assistant by an external HTTP endpoint.

    The datasource list is fetched during construction from the ``endpoint``
    in the assistant's JSON ``configuration``. Credentials for the request come
    from the assistant's ``certificate``.
    """

    assistant: AssistantHeader
    ds_list: Optional[list[AssistantOutDsSchema]] = None
    certificate: Optional[str] = None

    def __init__(self, assistant: AssistantHeader):
        self.assistant = assistant
        self.ds_list = None
        self.certificate = assistant.certificate
        self.get_ds_from_api()

    def get_ds_from_api(self):
        """Fetch and decode the datasource list from the configured endpoint.

        ``certificate`` is a JSON array of items with ``target``, ``key`` and
        ``value``. ``target`` selects whether the pair is sent as a header, a
        cookie or a query parameter; other targets are ignored. A missing
        certificate means no credentials are sent.

        Returns:
            The list of ``AssistantOutDsSchema``, also stored on ``self.ds_list``.

        Raises:
            Exception: if the HTTP status is not 200, or the response ``code``
                is neither 0 nor 200.
        """
        config: dict = json.loads(self.assistant.configuration)
        endpoint: str = config['endpoint']
        # certificate is Optional; json.loads(None) would raise TypeError.
        certificate_list: list = json.loads(self.certificate) if self.certificate else []

        header = {}
        cookies = {}
        param = {}
        for item in certificate_list:
            target = item['target']
            if target == 'header':
                header[item['key']] = item['value']
            elif target == 'cookie':
                cookies[item['key']] = item['value']
            elif target == 'param':
                param[item['key']] = item['value']

        res = requests.get(url=endpoint, params=param, headers=header, cookies=cookies, timeout=10)
        if res.status_code != 200:
            raise Exception(f"Failed to get datasource list from {endpoint}, status code: {res.status_code}")

        result_json: dict = json.loads(res.text)
        if result_json.get('code') not in (0, 200):
            raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}")

        # Treat an explicit null "data" field like an empty list instead of
        # failing while iterating None.
        raw_list = result_json.get('data') or []
        self.ds_list = [self.convert2schema(item, config) for item in raw_list]
        return self.ds_list

    def get_simple_ds_list(self):
        """Return ``{"id", "name", "description"}`` dicts for the fetched datasources.

        ``description`` is taken from each datasource's ``comment``.

        Raises:
            Exception: if the list is ``None`` or empty.
        """
        if self.ds_list:
            return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list]
        else:
            raise Exception("Datasource list is not found.")

    def get_db_schema(self, ds_id: int) -> str:
        """Render the tables and fields of datasource ``ds_id`` as a text schema.

        The header names the schema (``db_schema`` if non-empty, else
        ``dataBase``). Tables are prefixed with that name except for MySQL.
        Empty or missing comments are omitted.
        """
        ds = self.get_ds(ds_id)
        db_name = ds.db_schema or ds.dataBase
        parts = [f"【DB_ID】 {db_name}\n【Schema】\n"]
        for table in ds.tables:
            parts.append(f"# Table: {db_name}.{table.name}" if ds.type != "mysql" else f"# Table: {table.name}")
            # A None comment would otherwise be rendered literally as "None".
            parts.append(f", {table.comment}\n[\n" if table.comment else '\n[\n')
            field_lines = [
                f"({field.name}:{field.type}, {field.comment})" if field.comment
                else f"({field.name}:{field.type})"
                for field in table.fields
            ]
            parts.append(",\n".join(field_lines))
            parts.append('\n]\n')
        return "".join(parts)

    def get_ds(self, ds_id: int):
        """Return the fetched datasource whose ``id`` equals ``ds_id``.

        Raises:
            Exception: if the list is ``None`` or empty, or no datasource matches.
        """
        if not self.ds_list:
            raise Exception("Datasource list is not found.")
        for ds in self.ds_list:
            if ds.id == ds_id:
                return ds
        raise Exception(f"Datasource with id {ds_id} not found.")

    def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSchema:
        """Convert one raw datasource dict from the endpoint into a schema object.

        If ``config["encrypt"]`` is truthy, the connection attributes are first
        decrypted with ``aes_key`` / ``aes_iv``. The ``id`` is derived with
        ``string_to_numeric_hash`` from the connection attributes rather than
        taken from the payload. The payload's ``schema`` key, when present,
        becomes ``db_schema``.

        Raises:
            Exception: if decrypting an attribute fails.
        """
        # Work on a copy so the caller's dict is neither decrypted nor popped in place.
        ds_dict = dict(ds_dict)

        if config.get('encrypt', False):
            key = config.get('aes_key', None)
            iv = config.get('aes_iv', None)
            # NOTE(review): 'db_schema' is decrypted, but 'schema' (which is preferred
            # below and used in the id) is not — confirm which key the endpoint sends.
            for attr in ('host', 'user', 'password', 'dataBase', 'db_schema'):
                if ds_dict.get(attr):
                    try:
                        ds_dict[attr] = simple_aes_decrypt(ds_dict[attr], key, iv)
                    except Exception as e:
                        raise Exception(
                            f"Failed to decrypt {attr} for datasource {ds_dict.get('name')}, error: {str(e)}"
                        ) from e

        id_marker = ''.join(
            str(ds_dict.get(attr, '')) + '--sqlbot--'
            for attr in ('name', 'type', 'host', 'port', 'user', 'dataBase', 'schema')
            if attr in ds_dict
        )
        ds_id = string_to_numeric_hash(id_marker)
        db_schema = ds_dict.pop('schema', ds_dict.get('db_schema', ''))
        return AssistantOutDsSchema(**{**ds_dict, "id": ds_id, "db_schema": db_schema})
class AssistantOutDsFactory:
    """Factory that builds ``AssistantOutDs`` objects for an assistant."""

    @staticmethod
    def get_instance(assistant: AssistantHeader) -> AssistantOutDs:
        """Construct a fresh ``AssistantOutDs`` for ``assistant``.

        Nothing is cached: every call creates a new instance, and
        ``AssistantOutDs.__init__`` fetches the remote datasource list.
        """
        instance = AssistantOutDs(assistant)
        return instance
def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
    """Create a SQLAlchemy engine for an externally supplied datasource.

    The URI is built by ``apps.db.db.get_uri_from_config`` from ``ds``'s
    connection settings. Extra JDBC parameters are never forwarded
    (``extraJdbc`` is always empty).

    Every engine uses a 30-second ``pool_timeout``. SQL Server and Oracle
    engines get no ``connect_args``. Every other type also gets a
    ``connect_timeout`` of 30. PostgreSQL (``"pg"``) with a ``db_schema``
    additionally gets ``search_path`` set through libpq ``options``.

    Args:
        ds: Datasource description returned by the external endpoint.

    Returns:
        A new SQLAlchemy ``Engine``.
    """
    # Import the submodule explicitly: a bare ``import urllib`` does not
    # guarantee that ``urllib.parse`` has been loaded.
    from urllib.parse import quote

    from apps.db.db import get_uri_from_config

    timeout: int = 30
    conf = DatasourceConf(
        host=ds.host,
        port=ds.port,
        username=ds.user,
        password=ds.password,
        database=ds.dataBase,
        driver='',
        extraJdbc='',
        dbSchema=ds.db_schema or ''
    )
    uri = get_uri_from_config(ds.type, conf)

    if ds.type in ('sqlServer', 'oracle'):
        return create_engine(uri, pool_timeout=timeout)

    connect_args = {"connect_timeout": timeout}
    if ds.type == "pg" and ds.db_schema:
        # NOTE(review): connect_args go straight to the DBAPI connect() call, not
        # into a URL. Percent-encoding therefore looks unnecessary and would change
        # schema names containing non-ASCII/special characters — confirm.
        connect_args["options"] = f"-c search_path={quote(ds.db_schema)}"
    return create_engine(uri, connect_args=connect_args, pool_timeout=timeout)