# SQLBot/backend/apps/db/db.py
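"""Database connectivity helpers for SQLBot datasources.

Builds SQLAlchemy connection URIs, opens engines and sessions, checks
connectivity, and reads versions, schemas, tables, fields, and query results
for the supported datasource types (MySQL, SQL Server, PostgreSQL/Excel,
Oracle, ClickHouse, Dameng, Doris, Redshift).
"""
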
import base64
import json
import platform
import urllib.parse
from decimal import Decimal
from typing import Optional
from apps.db.db_sql import get_table_sql, get_field_sql, get_version_sql
from common.error import ParseSQLResultError
if platform.system() != "Darwin":
    import dmPython
import pymysql
import redshift_connector
from sqlalchemy import create_engine, text, Engine
from sqlalchemy.orm import sessionmaker
from apps.datasource.models.datasource import DatasourceConf, CoreDatasource, TableSchema, ColumnSchema
from apps.datasource.utils.utils import aes_decrypt
from apps.db.constant import DB, ConnectType
from apps.db.engine import get_engine_config
from apps.system.crud.assistant import get_ds_engine
from apps.system.schemas.system_schema import AssistantOutDsSchema
from common.core.deps import Trans
from common.utils.utils import SQLBotLogUtil
from fastapi import HTTPException


def get_uri(ds: CoreDatasource) -> str:
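    """Build a SQLAlchemy URI for a stored datasource (Excel datasources use the built-in engine config)."""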
    conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
    return get_uri_from_config(ds.type, conf)


def get_uri_from_config(type: str, conf: DatasourceConf) -> str:
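    """Assemble a driver-specific SQLAlchemy URL from a DatasourceConf.

    Credentials are URL-quoted and ``extraJdbc`` is appended as query parameters
    when present. For illustration only (hypothetical values): a MySQL conf with
    host ``db.example.com``, port 3306, database ``sales``, user ``user``,
    password ``p@ss`` and extraJdbc ``charset=utf8mb4`` would produce
    ``mysql+pymysql://user:p%40ss@db.example.com:3306/sales?charset=utf8mb4``.
    """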
    db_url: str
    if type == "mysql":
        if conf.extraJdbc is not None and conf.extraJdbc != '':
            db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}"
        else:
            db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}"
    elif type == "sqlServer":
        if conf.extraJdbc is not None and conf.extraJdbc != '':
            db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}"
        else:
            db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}"
    elif type == "pg" or type == "excel":
        if conf.extraJdbc is not None and conf.extraJdbc != '':
            db_url = f"postgresql+psycopg2://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}"
        else:
            db_url = f"postgresql+psycopg2://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}"
    elif type == "oracle":
        if conf.mode == "service_name":
            if conf.extraJdbc is not None and conf.extraJdbc != '':
                db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}?service_name={conf.database}&{conf.extraJdbc}"
            else:
                db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}?service_name={conf.database}"
        else:
            if conf.extraJdbc is not None and conf.extraJdbc != '':
                db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}"
            else:
                db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}"
    elif type == "ck":
        if conf.extraJdbc is not None and conf.extraJdbc != '':
            db_url = f"clickhouse+http://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}"
        else:
            db_url = f"clickhouse+http://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}"
    else:
        # Raising a bare string is invalid in Python 3; raise a proper exception instead.
        raise ValueError('The datasource type is not supported.')
    return db_url


def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine:
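    """Create a SQLAlchemy Engine for the datasource, applying connect/pool timeouts per driver."""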
    conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
    if conf.timeout is None:
        conf.timeout = timeout
    if timeout > 0:
        conf.timeout = timeout
    if ds.type == "pg":
        if conf.dbSchema is not None and conf.dbSchema != "":
            engine = create_engine(get_uri(ds),
                                   connect_args={"options": f"-c search_path={urllib.parse.quote(conf.dbSchema)}",
                                                 "connect_timeout": conf.timeout},
                                   pool_timeout=conf.timeout)
        else:
            engine = create_engine(get_uri(ds),
                                   connect_args={"connect_timeout": conf.timeout},
                                   pool_timeout=conf.timeout)
    elif ds.type == 'sqlServer':
        engine = create_engine(get_uri(ds), pool_timeout=conf.timeout)
    elif ds.type == 'oracle':
        engine = create_engine(get_uri(ds), pool_timeout=conf.timeout)
    else:  # mysql, ck
        engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout}, pool_timeout=conf.timeout)
    return engine


def get_session(ds: CoreDatasource | AssistantOutDsSchema):
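    """Open a new ORM session bound to the datasource's engine (assistant datasources use get_ds_engine)."""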
    engine = get_engine(ds) if isinstance(ds, CoreDatasource) else get_ds_engine(ds)
    session_maker = sessionmaker(bind=engine)
    session = session_maker()
    return session


def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDsSchema, is_raise: bool = False):
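    """Probe the datasource connection; returns True/False, or raises HTTPException when is_raise is set."""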
    if isinstance(ds, CoreDatasource):
        db = DB.get_db(ds.type)
        if db.connect_type == ConnectType.sqlalchemy:
            conn = get_engine(ds, 10)
            try:
                with conn.connect() as connection:
                    SQLBotLogUtil.info("success")
                    return True
            except Exception as e:
                SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}")
                if is_raise:
                    raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
                return False
        else:
            conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration)))
            if ds.type == 'dm':
                with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
                                      port=conf.port) as conn, conn.cursor() as cursor:
                    try:
                        cursor.execute('select 1', timeout=10).fetchall()
                        SQLBotLogUtil.info("success")
                        return True
                    except Exception as e:
                        SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}")
                        if is_raise:
                            raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
                        return False
            elif ds.type == 'doris':
                with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
                                     port=conf.port, db=conf.database, connect_timeout=10,
                                     read_timeout=10) as conn, conn.cursor() as cursor:
                    try:
                        cursor.execute('select 1')
                        SQLBotLogUtil.info("success")
                        return True
                    except Exception as e:
                        SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}")
                        if is_raise:
                            raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
                        return False
            elif ds.type == 'redshift':
                with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database,
                                                user=conf.username, password=conf.password,
                                                timeout=10) as conn, conn.cursor() as cursor:
                    try:
                        cursor.execute('select 1')
                        SQLBotLogUtil.info("success")
                        return True
                    except Exception as e:
                        SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}")
                        if is_raise:
                            raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
                        return False
    else:
        conn = get_ds_engine(ds)
        try:
            with conn.connect() as connection:
                SQLBotLogUtil.info("success")
                return True
        except Exception as e:
            SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}")
            if is_raise:
                raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}')
            return False
    return False


def get_version(ds: CoreDatasource | AssistantOutDsSchema):
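    """Return the server version string for the datasource, or '' when it cannot be determined."""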
    version = ''
    conf = None
    if isinstance(ds, CoreDatasource):
        conf = DatasourceConf(
            **json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
    if isinstance(ds, AssistantOutDsSchema):
        conf = DatasourceConf()
        conf.host = ds.host
        conf.port = ds.port
        conf.username = ds.user
        conf.password = ds.password
        conf.database = ds.dataBase
        conf.dbSchema = ds.db_schema
        conf.timeout = 10
    db = DB.get_db(ds.type)
    sql = get_version_sql(ds, conf)
    try:
        if db.connect_type == ConnectType.sqlalchemy:
            with get_session(ds) as session:
                with session.execute(text(sql)) as result:
                    res = result.fetchall()
                    version = res[0][0]
        else:
            if ds.type == 'dm':
                with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
                                      port=conf.port) as conn, conn.cursor() as cursor:
                    cursor.execute(sql, timeout=10)
                    res = cursor.fetchall()
                    version = res[0][0]
            elif ds.type == 'doris':
                with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
                                     port=conf.port, db=conf.database, connect_timeout=10,
                                     read_timeout=10) as conn, conn.cursor() as cursor:
                    cursor.execute(sql)
                    res = cursor.fetchall()
                    version = res[0][0]
            elif ds.type == 'redshift':
                version = ''
    except Exception as e:
        SQLBotLogUtil.error(f"Failed to read datasource version: {e}")
        version = ''
    return version.decode() if isinstance(version, bytes) else version


def get_schema(ds: CoreDatasource):
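    """List schema names for the datasource (SQL Server, PostgreSQL/Excel, Oracle, DM, Redshift)."""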
    conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
    db = DB.get_db(ds.type)
    if db.connect_type == ConnectType.sqlalchemy:
        with get_session(ds) as session:
            sql: str = ''
            if ds.type == "sqlServer":
                sql = """select name from sys.schemas"""
            elif ds.type == "pg" or ds.type == "excel":
                sql = """SELECT nspname
                         FROM pg_namespace"""
            elif ds.type == "oracle":
                sql = """select * from all_users"""
            with session.execute(text(sql)) as result:
                res = result.fetchall()
                res_list = [item[0] for item in res]
                return res_list
    else:
        if ds.type == 'dm':
            with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
                                  port=conf.port) as conn, conn.cursor() as cursor:
                cursor.execute("""select OBJECT_NAME from dba_objects where object_type='SCH'""",
                               timeout=conf.timeout)
                res = cursor.fetchall()
                res_list = [item[0] for item in res]
                return res_list
        elif ds.type == 'redshift':
            with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database,
                                            user=conf.username, password=conf.password,
                                            timeout=conf.timeout) as conn, conn.cursor() as cursor:
                cursor.execute("""SELECT nspname FROM pg_namespace""")
                res = cursor.fetchall()
                res_list = [item[0] for item in res]
                return res_list


def get_tables(ds: CoreDatasource):
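    """Return TableSchema entries for the datasource using the dialect-specific SQL from get_table_sql."""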
    conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
    db = DB.get_db(ds.type)
    sql = get_table_sql(ds, conf)
    if db.connect_type == ConnectType.sqlalchemy:
        with get_session(ds) as session:
            with session.execute(text(sql)) as result:
                res = result.fetchall()
                res_list = [TableSchema(*item) for item in res]
                return res_list
    else:
        if ds.type == 'dm':
            with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
                                  port=conf.port) as conn, conn.cursor() as cursor:
                cursor.execute(sql, timeout=conf.timeout)
                res = cursor.fetchall()
                res_list = [TableSchema(*item) for item in res]
                return res_list
        elif ds.type == 'doris':
            with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
                                 port=conf.port, db=conf.database, connect_timeout=conf.timeout,
                                 read_timeout=conf.timeout) as conn, conn.cursor() as cursor:
                cursor.execute(sql)
                res = cursor.fetchall()
                res_list = [TableSchema(*item) for item in res]
                return res_list
        elif ds.type == 'redshift':
            with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database,
                                            user=conf.username, password=conf.password,
                                            timeout=conf.timeout) as conn, conn.cursor() as cursor:
                cursor.execute(sql)
                res = cursor.fetchall()
                res_list = [TableSchema(*item) for item in res]
                return res_list


def get_fields(ds: CoreDatasource, table_name: str = None):
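    """Return ColumnSchema entries via the dialect-specific SQL from get_field_sql (optionally filtered by table_name)."""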
    conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
    db = DB.get_db(ds.type)
    sql = get_field_sql(ds, conf, table_name)
    if db.connect_type == ConnectType.sqlalchemy:
        with get_session(ds) as session:
            with session.execute(text(sql)) as result:
                res = result.fetchall()
                res_list = [ColumnSchema(*item) for item in res]
                return res_list
    else:
        if ds.type == 'dm':
            with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
                                  port=conf.port) as conn, conn.cursor() as cursor:
                cursor.execute(sql, timeout=conf.timeout)
                res = cursor.fetchall()
                res_list = [ColumnSchema(*item) for item in res]
                return res_list
        elif ds.type == 'doris':
            with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
                                 port=conf.port, db=conf.database, connect_timeout=conf.timeout,
                                 read_timeout=conf.timeout) as conn, conn.cursor() as cursor:
                cursor.execute(sql)
                res = cursor.fetchall()
                res_list = [ColumnSchema(*item) for item in res]
                return res_list
        elif ds.type == 'redshift':
            with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database,
                                            user=conf.username, password=conf.password,
                                            timeout=conf.timeout) as conn, conn.cursor() as cursor:
                cursor.execute(sql)
                res = cursor.fetchall()
                res_list = [ColumnSchema(*item) for item in res]
                return res_list


def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False):
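    """Execute a SQL statement and return {"fields": ..., "data": ..., "sql": base64-encoded statement}.

    Column names are lower-cased unless ``origin_column`` is True, and Decimal
    values are converted to float.
    """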
    while sql.endswith(';'):
        sql = sql[:-1]
    db = DB.get_db(ds.type)
    if db.connect_type == ConnectType.sqlalchemy:
        with get_session(ds) as session:
            with session.execute(text(sql)) as result:
                try:
                    columns = result.keys()._keys if origin_column \
                        else [item.lower() for item in result.keys()._keys]
                    res = result.fetchall()
                    result_list = [
                        {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in
                         enumerate(tuple_item)}
                        for tuple_item in res
                    ]
                    return {"fields": columns, "data": result_list,
                            "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}
                except Exception as ex:
                    raise ParseSQLResultError(str(ex))
    else:
        conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration)))
        if ds.type == 'dm':
            with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
                                  port=conf.port) as conn, conn.cursor() as cursor:
                try:
                    cursor.execute(sql, timeout=conf.timeout)
                    res = cursor.fetchall()
                    columns = [field[0] for field in cursor.description] if origin_column \
                        else [field[0].lower() for field in cursor.description]
                    result_list = [
                        {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in
                         enumerate(tuple_item)}
                        for tuple_item in res
                    ]
                    return {"fields": columns, "data": result_list,
                            "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}
                except Exception as ex:
                    raise ParseSQLResultError(str(ex))
        elif ds.type == 'doris':
            with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
                                 port=conf.port, db=conf.database, connect_timeout=conf.timeout,
                                 read_timeout=conf.timeout) as conn, conn.cursor() as cursor:
                try:
                    cursor.execute(sql)
                    res = cursor.fetchall()
                    columns = [field[0] for field in cursor.description] if origin_column \
                        else [field[0].lower() for field in cursor.description]
                    result_list = [
                        {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in
                         enumerate(tuple_item)}
                        for tuple_item in res
                    ]
                    return {"fields": columns, "data": result_list,
                            "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}
                except Exception as ex:
                    raise ParseSQLResultError(str(ex))
        elif ds.type == 'redshift':
            with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database,
                                            user=conf.username, password=conf.password,
                                            timeout=conf.timeout) as conn, conn.cursor() as cursor:
                try:
                    cursor.execute(sql)
                    res = cursor.fetchall()
                    columns = [field[0] for field in cursor.description] if origin_column \
                        else [field[0].lower() for field in cursor.description]
                    result_list = [
                        {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in
                         enumerate(tuple_item)}
                        for tuple_item in res
                    ]
                    return {"fields": columns, "data": result_list,
                            "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}
                except Exception as ex:
                    raise ParseSQLResultError(str(ex))
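

# Illustrative usage sketch (not part of this module's API): assumes a
# CoreDatasource row `ds` and a Trans instance `trans` obtained elsewhere
# in the application.
#
#     if check_connection(trans, ds):
#         print(get_version(ds))
#         tables = get_tables(ds)
#         result = exec_sql(ds, "select 1")
#         print(result["fields"], result["data"])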