237 lines
9.7 KiB
Python
237 lines
9.7 KiB
Python
import json
|
|
import urllib
|
|
from typing import Optional
|
|
|
|
import requests
|
|
from fastapi import FastAPI
|
|
from sqlalchemy import Engine, create_engine
|
|
from sqlmodel import Session, select
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
|
|
from apps.datasource.models.datasource import CoreDatasource, DatasourceConf
|
|
from apps.system.models.system_model import AssistantModel
|
|
from apps.system.schemas.auth import CacheName, CacheNamespace
|
|
from apps.system.schemas.system_schema import AssistantHeader, AssistantOutDsSchema, UserInfoDTO
|
|
from common.core.config import settings
|
|
from common.core.db import engine
|
|
from common.core.sqlbot_cache import cache
|
|
from common.utils.aes_crypto import simple_aes_decrypt
|
|
from common.utils.utils import string_to_numeric_hash
|
|
|
|
@cache(
    namespace=CacheNamespace.EMBEDDED_INFO,
    cacheName=CacheName.ASSISTANT_INFO,
    keyExpression="assistant_id",
)
async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None:
    """Fetch one assistant record by primary key.

    The call is wrapped by the project's ``cache`` decorator, configured with the
    ``EMBEDDED_INFO`` namespace, the ``ASSISTANT_INFO`` cache name and
    ``assistant_id`` as the key expression.

    Args:
        session: Database session used for the lookup.
        assistant_id: Primary key of the ``AssistantModel`` row.

    Returns:
        Whatever ``session.get`` yields for that key: the model, or ``None``.
    """
    return session.get(AssistantModel, assistant_id)
|
|
|
|
|
|
def get_assistant_user(*, id: int):
    """Build the fixed internal user that represents an assistant.

    Args:
        id: Value stored as the user's ``id``.

    Returns:
        A ``UserInfoDTO`` whose account and name are ``sqlbot-inner-assistant``,
        with ``oid`` 1 and the matching ``@sqlbot.com`` e-mail address.
    """
    fields = {
        "id": id,
        "account": "sqlbot-inner-assistant",
        "oid": 1,
        "name": "sqlbot-inner-assistant",
        "email": "sqlbot-inner-assistant@sqlbot.com",
    }
    return UserInfoDTO(**fields)
|
|
|
|
|
|
def get_assistant_ds(session: Session, llm_service) -> list[dict]:
    """Return the datasources the current assistant may use.

    Assistants of type 0 or 2 read ``CoreDatasource`` rows from the local
    database, filtered by the ``oid`` stored in the assistant's JSON
    configuration. When the assistant is not online, the rows are further
    limited to the ids in the configuration's ``public_list``; without such a
    list nothing is returned. Every other assistant type obtains its
    datasources through ``AssistantOutDs``; the created instance is stored on
    ``llm_service.out_ds_instance``.

    Args:
        session: Database session used to query ``CoreDatasource``.
        llm_service: Object exposing ``current_assistant``; receives
            ``out_ds_instance`` for assistant types other than 0 and 2.

    Returns:
        A list of dicts with the keys ``id``, ``name`` and ``description``.

    Raises:
        KeyError: The configuration of a type 0/2 assistant has no ``oid``.
        Exception: Propagated from ``AssistantOutDs`` for other assistant types.
    """
    assistant: AssistantHeader = llm_service.current_assistant

    if assistant.type not in (0, 2):
        out_ds_instance: AssistantOutDs = AssistantOutDsFactory.get_instance(assistant)
        llm_service.out_ds_instance = out_ds_instance
        return out_ds_instance.get_simple_ds_list()

    configuration = assistant.configuration
    if not configuration:
        # Without a configuration there is no oid to filter by, so no
        # datasource is exposed (instead of failing on an undefined query).
        return []

    config: dict = json.loads(configuration)
    oid = int(config['oid'])

    public_list = None
    if not assistant.online:
        public_list = config.get('public_list')
        if not public_list:
            # Offline assistants only ever see explicitly published datasources.
            return []

    stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(
        CoreDatasource.oid == oid)
    if public_list:
        stmt = stmt.where(CoreDatasource.id.in_(public_list))

    return [
        {
            "id": ds.id,
            "name": ds.name,
            "description": ds.description,
        }
        for ds in session.exec(stmt)
    ]
|
|
|
|
|
|
def init_dynamic_cors(app: FastAPI):
    """Extend the app's CORS allow-list with the domains stored on assistants.

    Reads every ``AssistantModel`` ordered by ``create_time``, splits each
    non-empty ``domain`` value on commas, strips the pieces and keeps the first
    occurrence of each. If ``app.user_middleware`` contains an entry whose class
    is ``CORSMiddleware``, its ``allow_origins`` kwarg is replaced by the union of
    ``settings.all_cors_origins`` and those domains.

    Args:
        app: Application whose registered middleware entries are inspected.

    Returns:
        ``None`` when no exception occurs; ``(False, exception)`` otherwise.
    """
    try:
        with Session(engine) as session:
            assistants = session.exec(
                select(AssistantModel).order_by(AssistantModel.create_time)
            ).all()

            stripped = (
                piece.strip()
                for record in assistants
                if record.domain
                for piece in record.domain.split(',')
            )
            # dict.fromkeys keeps first-seen order while dropping duplicates.
            unique_domains = list(dict.fromkeys(name for name in stripped if name))

            cors_middleware = next(
                (entry for entry in app.user_middleware if entry.cls == CORSMiddleware),
                None,
            )
            if cors_middleware:
                merged = set(settings.all_cors_origins + unique_domains)
                cors_middleware.kwargs['allow_origins'] = list(merged)
    except Exception as e:
        return False, e
|
|
|
|
|
|
class AssistantOutDs:
    """Datasources that an assistant obtains from an external HTTP endpoint.

    Constructing an instance immediately requests the datasource list from the
    ``endpoint`` named in the assistant's JSON configuration and keeps the
    converted entries in ``ds_list``.
    """

    # Assistant header the datasource list belongs to.
    assistant: AssistantHeader
    # Converted datasources; None until get_ds_from_api has succeeded.
    ds_list: Optional[list[AssistantOutDsSchema]] = None
    # JSON text describing the credentials to send; copied from the assistant.
    certificate: Optional[str] = None

    def __init__(self, assistant: AssistantHeader):
        """Store the assistant and load its datasources right away.

        Raises:
            Exception: Propagated from ``get_ds_from_api``.
        """
        self.assistant = assistant
        self.ds_list = None
        self.certificate = assistant.certificate
        self.get_ds_from_api()

    # @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
    def get_ds_from_api(self):
        """Request the datasource list from the configured endpoint.

        Each certificate entry is a mapping with ``target``, ``key`` and
        ``value``; ``target`` selects whether the pair is sent as a header, a
        cookie or a query parameter. Entries with any other target are ignored.
        A missing certificate sends the request without credentials.

        Returns:
            The converted datasource list, also stored in ``self.ds_list``.

        Raises:
            Exception: The HTTP status is not 200, or the response ``code`` is
                neither 0 nor 200.
        """
        config: dict = json.loads(self.assistant.configuration)
        endpoint: str = config['endpoint']
        # certificate is Optional; json.loads(None) would raise TypeError.
        certificate_items: list = json.loads(self.certificate) if self.certificate else []

        header = {}
        cookies = {}
        param = {}
        for item in certificate_items:
            target = item['target']
            if target == 'header':
                header[item['key']] = item['value']
            elif target == 'cookie':
                cookies[item['key']] = item['value']
            elif target == 'param':
                param[item['key']] = item['value']

        res = requests.get(url=endpoint, params=param, headers=header, cookies=cookies, timeout=10)
        if res.status_code != 200:
            raise Exception(f"Failed to get datasource list from {endpoint}, status code: {res.status_code}")

        result_json: dict = json.loads(res.text)
        if result_json.get('code') not in (0, 200):
            raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}")

        # "data": null is treated like a missing key instead of failing on iteration.
        raw_list = result_json.get('data') or []
        self.ds_list = [self.convert2schema(item, config) for item in raw_list]
        return self.ds_list

    def get_simple_ds_list(self):
        """Return ``id``/``name``/``description`` dicts for every datasource.

        ``description`` is taken from each datasource's ``comment``.

        Raises:
            Exception: ``ds_list`` is ``None`` or empty.
        """
        if not self.ds_list:
            raise Exception("Datasource list is not found.")
        return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list]

    def get_db_schema(self, ds_id: int) -> str:
        """Render the tables and fields of one datasource as prompt text.

        The database name is ``db_schema`` when it is set and non-empty,
        otherwise ``dataBase``. Table names are prefixed with that name unless
        the datasource type is ``mysql``. A table or field comment is appended
        only when it is non-empty; a ``None`` comment is treated like an empty
        one instead of being printed as ``None``.

        Args:
            ds_id: Id of a datasource in ``ds_list``.

        Raises:
            Exception: Propagated from ``get_ds``.
        """
        ds = self.get_ds(ds_id)
        db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase

        parts = [f"【DB_ID】 {db_name}\n【Schema】\n"]
        for table in ds.tables:
            title = f"# Table: {db_name}.{table.name}" if ds.type != "mysql" else f"# Table: {table.name}"
            if table.comment:
                title += f", {table.comment}"
            columns = [
                f"({field.name}:{field.type}, {field.comment})" if field.comment
                else f"({field.name}:{field.type})"
                for field in table.fields
            ]
            parts.append(f"{title}\n[\n" + ",\n".join(columns) + "\n]\n")
        return "".join(parts)

    def get_ds(self, ds_id: int):
        """Return the datasource whose ``id`` equals ``ds_id``.

        Raises:
            Exception: ``ds_list`` is ``None`` or empty, or holds no such id.
        """
        if not self.ds_list:
            raise Exception("Datasource list is not found.")
        for ds in self.ds_list:
            if ds.id == ds_id:
                return ds
        raise Exception(f"Datasource with id {ds_id} not found.")

    def convert2schema(self, ds_dict: dict, config: dict) -> AssistantOutDsSchema:
        """Turn one raw datasource mapping into an ``AssistantOutDsSchema``.

        When ``config['encrypt']`` is truthy, the non-empty ``host``, ``user``,
        ``password``, ``dataBase`` and ``db_schema`` values are decrypted with
        ``simple_aes_decrypt`` using ``config['aes_key']`` and ``config['aes_iv']``.
        The id is ``string_to_numeric_hash`` of the present identifying
        attributes, each followed by ``--sqlbot--``.

        Note:
            ``ds_dict`` is modified in place: decrypted values overwrite the
            encrypted ones and the ``schema`` key is removed.

        Args:
            ds_dict: Raw datasource mapping from the endpoint response.
            config: Parsed assistant configuration.

        Raises:
            Exception: A value could not be decrypted; the original error is
                attached as the cause.
        """
        if config.get('encrypt', False):
            key = config.get('aes_key', None)
            iv = config.get('aes_iv', None)
            for attr in ('host', 'user', 'password', 'dataBase', 'db_schema'):
                if attr in ds_dict and ds_dict[attr]:
                    try:
                        ds_dict[attr] = simple_aes_decrypt(ds_dict[attr], key, iv)
                    except Exception as e:
                        raise Exception(
                            f"Failed to decrypt {attr} for datasource {ds_dict.get('name')}, error: {str(e)}"
                        ) from e

        attr_list = ['name', 'type', 'host', 'port', 'user', 'dataBase', 'schema']
        id_marker = ''.join(
            str(ds_dict[attr]) + '--sqlbot--' for attr in attr_list if attr in ds_dict
        )
        ds_id = string_to_numeric_hash(id_marker)

        # 'schema' wins over 'db_schema' and is passed on under the latter name only.
        db_schema = ds_dict.get('schema', ds_dict.get('db_schema', ''))
        ds_dict.pop("schema", None)
        return AssistantOutDsSchema(**{**ds_dict, "id": ds_id, "db_schema": db_schema})
|
|
|
|
class AssistantOutDsFactory:
    """Factory for ``AssistantOutDs`` objects."""

    @staticmethod
    def get_instance(assistant: AssistantHeader) -> AssistantOutDs:
        """Create a new ``AssistantOutDs`` for ``assistant``.

        Args:
            assistant: Header passed unchanged to the ``AssistantOutDs`` constructor.

        Returns:
            The freshly constructed instance.
        """
        instance = AssistantOutDs(assistant)
        return instance
|
|
|
|
|
|
def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
    """Create a SQLAlchemy engine for an externally supplied datasource.

    The connection URI comes from ``get_uri_from_config`` for the datasource
    type. A 30 second timeout is used as ``pool_timeout`` for every engine and
    as ``connect_timeout`` for all types except ``sqlServer`` and ``oracle``.
    For ``pg`` datasources with a ``db_schema``, the URL-quoted schema is also
    passed as the connection's ``search_path`` option.

    Args:
        ds: Datasource description providing host, port, credentials, database,
            type and optional schema.

    Returns:
        The engine returned by ``create_engine``.
    """
    # Import the submodule explicitly: a bare ``import urllib`` does not
    # guarantee that ``urllib.parse`` has been loaded.
    from urllib.parse import quote

    timeout: int = 30
    conf = DatasourceConf(
        host=ds.host,
        port=ds.port,
        username=ds.user,
        password=ds.password,
        database=ds.dataBase,
        driver='',
        extraJdbc=ds.extraParams,
        dbSchema=ds.db_schema or '',
    )
    # The extra parameters are cleared again before the URI is built.
    conf.extraJdbc = ''

    # Function-level import kept from the original (presumably avoids a
    # circular import between the modules; confirm before moving it).
    from apps.db.db import get_uri_from_config
    uri = get_uri_from_config(ds.type, conf)

    if ds.type == "pg" and ds.db_schema:
        connect_args = {
            "options": f"-c search_path={quote(ds.db_schema)}",
            "connect_timeout": timeout,
        }
    elif ds.type in ('sqlServer', 'oracle'):
        # These types are created without any connect_args.
        return create_engine(uri, pool_timeout=timeout)
    else:
        connect_args = {"connect_timeout": timeout}

    return create_engine(uri, connect_args=connect_args, pool_timeout=timeout)
|