diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py new file mode 100644 index 0000000..6549740 --- /dev/null +++ b/backend/apps/db/db.py @@ -0,0 +1,403 @@ +import base64 +import json +import platform +import urllib.parse +from decimal import Decimal +from typing import Optional + +from apps.db.db_sql import get_table_sql, get_field_sql, get_version_sql +from common.error import ParseSQLResultError + +if platform.system() != "Darwin": + import dmPython +import pymysql +import redshift_connector +from sqlalchemy import create_engine, text, Engine +from sqlalchemy.orm import sessionmaker + +from apps.datasource.models.datasource import DatasourceConf, CoreDatasource, TableSchema, ColumnSchema +from apps.datasource.utils.utils import aes_decrypt +from apps.db.constant import DB, ConnectType +from apps.db.engine import get_engine_config +from apps.system.crud.assistant import get_ds_engine +from apps.system.schemas.system_schema import AssistantOutDsSchema +from common.core.deps import Trans +from common.utils.utils import SQLBotLogUtil +from fastapi import HTTPException + + +def get_uri(ds: CoreDatasource) -> str: + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + return get_uri_from_config(ds.type, conf) + + +def get_uri_from_config(type: str, conf: DatasourceConf) -> str: + db_url: str + if type == "mysql": + if conf.extraJdbc is not None and conf.extraJdbc != '': + db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" + else: + db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" + elif type == "sqlServer": + if conf.extraJdbc is not None and conf.extraJdbc != '': + db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" + else: + db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" + elif type == "pg" or type == "excel": + if conf.extraJdbc is not None and conf.extraJdbc != '': + db_url = f"postgresql+psycopg2://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" + else: + db_url = f"postgresql+psycopg2://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" + elif type == "oracle": + if conf.mode == "service_name": + if conf.extraJdbc is not None and conf.extraJdbc != '': + db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}?service_name={conf.database}&{conf.extraJdbc}" + else: + db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}?service_name={conf.database}" + else: + if conf.extraJdbc is not None and conf.extraJdbc != '': + db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" + else: + db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" + elif type == "ck": + if conf.extraJdbc is not None and conf.extraJdbc != '': + db_url = f"clickhouse+http://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" + else: + db_url = f"clickhouse+http://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" + else: + raise 'The datasource type not support.' + return db_url + + +def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + if conf.timeout is None: + conf.timeout = timeout + if timeout > 0: + conf.timeout = timeout + if ds.type == "pg": + if conf.dbSchema is not None and conf.dbSchema != "": + engine = create_engine(get_uri(ds), + connect_args={"options": f"-c search_path={urllib.parse.quote(conf.dbSchema)}", + "connect_timeout": conf.timeout}, + pool_timeout=conf.timeout) + else: + engine = create_engine(get_uri(ds), + connect_args={"connect_timeout": conf.timeout}, + pool_timeout=conf.timeout) + elif ds.type == 'sqlServer': + engine = create_engine(get_uri(ds), pool_timeout=conf.timeout) + elif ds.type == 'oracle': + engine = create_engine(get_uri(ds), + pool_timeout=conf.timeout) + else: # mysql, ck + engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout}, pool_timeout=conf.timeout) + return engine + + +def get_session(ds: CoreDatasource | AssistantOutDsSchema): + engine = get_engine(ds) if isinstance(ds, CoreDatasource) else get_ds_engine(ds) + session_maker = sessionmaker(bind=engine) + session = session_maker() + return session + + +def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDsSchema, is_raise: bool = False): + if isinstance(ds, CoreDatasource): + db = DB.get_db(ds.type) + if db.connect_type == ConnectType.sqlalchemy: + conn = get_engine(ds, 10) + try: + with conn.connect() as connection: + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + else: + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) + if ds.type == 'dm': + with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, + port=conf.port) as conn, conn.cursor() as cursor: + try: + cursor.execute('select 1', timeout=10).fetchall() + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + elif ds.type == 'doris': + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=10, + read_timeout=10) as conn, conn.cursor() as cursor: + try: + cursor.execute('select 1') + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + elif ds.type == 'redshift': + with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + timeout=10) as conn, conn.cursor() as cursor: + try: + cursor.execute('select 1') + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + else: + conn = get_ds_engine(ds) + try: + with conn.connect() as connection: + SQLBotLogUtil.info("success") + return True + except Exception as e: + SQLBotLogUtil.error(f"Datasource {ds.id} connection failed: {e}") + if is_raise: + raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid') + f': {e.args}') + return False + + return False + + +def get_version(ds: CoreDatasource | AssistantOutDsSchema): + version = '' + conf = None + if isinstance(ds, CoreDatasource): + conf = DatasourceConf( + **json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + if isinstance(ds, AssistantOutDsSchema): + conf = DatasourceConf() + conf.host = ds.host + conf.port = ds.port + conf.username = ds.user + conf.password = ds.password + conf.database = ds.dataBase + conf.dbSchema = ds.db_schema + conf.timeout = 10 + db = DB.get_db(ds.type) + sql = get_version_sql(ds, conf) + try: + if db.connect_type == ConnectType.sqlalchemy: + with get_session(ds) as session: + with session.execute(text(sql)) as result: + res = result.fetchall() + version = res[0][0] + else: + if ds.type == 'dm': + with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, + port=conf.port) as conn, conn.cursor() as cursor: + cursor.execute(sql, timeout=10) + res = cursor.fetchall() + version = res[0][0] + elif ds.type == 'doris': + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=10, + read_timeout=10) as conn, conn.cursor() as cursor: + cursor.execute(sql) + res = cursor.fetchall() + version = res[0][0] + elif ds.type == 'redshift': + version = '' + except Exception as e: + print(e) + version = '' + return version.decode() if isinstance(version, bytes) else version + + +def get_schema(ds: CoreDatasource): + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + db = DB.get_db(ds.type) + if db.connect_type == ConnectType.sqlalchemy: + with get_session(ds) as session: + sql: str = '' + if ds.type == "sqlServer": + sql = f"""select name from sys.schemas""" + elif ds.type == "pg" or ds.type == "excel": + sql = """SELECT nspname + FROM pg_namespace""" + elif ds.type == "oracle": + sql = f"""select * from all_users""" + with session.execute(text(sql)) as result: + res = result.fetchall() + res_list = [item[0] for item in res] + return res_list + else: + if ds.type == 'dm': + with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, + port=conf.port) as conn, conn.cursor() as cursor: + cursor.execute(f"""select OBJECT_NAME from dba_objects where object_type='SCH'""", timeout=conf.timeout) + res = cursor.fetchall() + res_list = [item[0] for item in res] + return res_list + elif ds.type == 'redshift': + with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + timeout=conf.timeout) as conn, conn.cursor() as cursor: + cursor.execute(f"""SELECT nspname FROM pg_namespace""") + res = cursor.fetchall() + res_list = [item[0] for item in res] + return res_list + + +def get_tables(ds: CoreDatasource): + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + db = DB.get_db(ds.type) + sql = get_table_sql(ds, conf) + if db.connect_type == ConnectType.sqlalchemy: + with get_session(ds) as session: + with session.execute(text(sql)) as result: + res = result.fetchall() + res_list = [TableSchema(*item) for item in res] + return res_list + else: + if ds.type == 'dm': + with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, + port=conf.port) as conn, conn.cursor() as cursor: + cursor.execute(sql, timeout=conf.timeout) + res = cursor.fetchall() + res_list = [TableSchema(*item) for item in res] + return res_list + elif ds.type == 'doris': + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=conf.timeout, + read_timeout=conf.timeout) as conn, conn.cursor() as cursor: + cursor.execute(sql) + res = cursor.fetchall() + res_list = [TableSchema(*item) for item in res] + return res_list + elif ds.type == 'redshift': + with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + timeout=conf.timeout) as conn, conn.cursor() as cursor: + cursor.execute(sql) + res = cursor.fetchall() + res_list = [TableSchema(*item) for item in res] + return res_list + + +def get_fields(ds: CoreDatasource, table_name: str = None): + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() + db = DB.get_db(ds.type) + sql = get_field_sql(ds, conf, table_name) + if db.connect_type == ConnectType.sqlalchemy: + with get_session(ds) as session: + with session.execute(text(sql)) as result: + res = result.fetchall() + res_list = [ColumnSchema(*item) for item in res] + return res_list + else: + if ds.type == 'dm': + with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, + port=conf.port) as conn, conn.cursor() as cursor: + cursor.execute(sql, timeout=conf.timeout) + res = cursor.fetchall() + res_list = [ColumnSchema(*item) for item in res] + return res_list + elif ds.type == 'doris': + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=conf.timeout, + read_timeout=conf.timeout) as conn, conn.cursor() as cursor: + cursor.execute(sql) + res = cursor.fetchall() + res_list = [ColumnSchema(*item) for item in res] + return res_list + elif ds.type == 'redshift': + with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + timeout=conf.timeout) as conn, conn.cursor() as cursor: + cursor.execute(sql) + res = cursor.fetchall() + res_list = [ColumnSchema(*item) for item in res] + return res_list + + +def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False): + while sql.endswith(';'): + sql = sql[:-1] + + db = DB.get_db(ds.type) + if db.connect_type == ConnectType.sqlalchemy: + with get_session(ds) as session: + with session.execute(text(sql)) as result: + try: + columns = result.keys()._keys if origin_column else [item.lower() for item in result.keys()._keys] + res = result.fetchall() + result_list = [ + {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in + enumerate(tuple_item)} + for tuple_item in res + ] + return {"fields": columns, "data": result_list, + "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} + except Exception as ex: + raise ParseSQLResultError(str(ex)) + else: + conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) + if ds.type == 'dm': + with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, + port=conf.port) as conn, conn.cursor() as cursor: + try: + cursor.execute(sql, timeout=conf.timeout) + res = cursor.fetchall() + columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for + field in + cursor.description] + result_list = [ + {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in + enumerate(tuple_item)} + for tuple_item in res + ] + return {"fields": columns, "data": result_list, + "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} + except Exception as ex: + raise ParseSQLResultError(str(ex)) + elif ds.type == 'doris': + with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, + port=conf.port, db=conf.database, connect_timeout=conf.timeout, + read_timeout=conf.timeout) as conn, conn.cursor() as cursor: + try: + cursor.execute(sql) + res = cursor.fetchall() + columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for + field in + cursor.description] + result_list = [ + {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in + enumerate(tuple_item)} + for tuple_item in res + ] + return {"fields": columns, "data": result_list, + "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} + except Exception as ex: + raise ParseSQLResultError(str(ex)) + elif ds.type == 'redshift': + with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, + password=conf.password, + timeout=conf.timeout) as conn, conn.cursor() as cursor: + try: + cursor.execute(sql) + res = cursor.fetchall() + columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for + field in + cursor.description] + result_list = [ + {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in + enumerate(tuple_item)} + for tuple_item in res + ] + return {"fields": columns, "data": result_list, + "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} + except Exception as ex: + raise ParseSQLResultError(str(ex))