Files
SQLBot/backend/apps/datasource/crud/datasource.py
2025-09-08 16:36:22 +08:00

374 lines
15 KiB
Python

import datetime
import json
from typing import List, Optional
from fastapi import HTTPException
from sqlalchemy import and_, text
from sqlmodel import select
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
from apps.datasource.utils.utils import aes_decrypt
from apps.db.constant import DB
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
from apps.db.engine import get_engine_config, get_engine_conn
from apps.db.type import db_type_relation
from common.core.deps import SessionDep, CurrentUser, Trans
from common.utils.utils import deepcopy_ignore_extra
from .table import get_tables_by_ds_id
from ..crud.field import delete_field_by_ds_id, update_field
from ..crud.table import delete_table_by_ds_id, update_table
from ..models.datasource import CoreDatasource, CreateDatasource, CoreTable, CoreField, ColumnSchema, TableObj, \
DatasourceConf, TableAndFields
def get_datasource_list(session: SessionDep, user: CurrentUser, oid: Optional[int] = None) -> List[CoreDatasource]:
    """Return the datasources whose ``oid`` matches the effective oid, ordered by name.

    The effective oid is ``user.oid`` (or ``1`` when that is ``None``). When
    ``user.isAdmin`` is truthy and a truthy ``oid`` argument is given, the
    argument takes precedence.
    """
    if user.isAdmin and oid:
        target_oid = oid
    elif user.oid is not None:
        target_oid = user.oid
    else:
        target_oid = 1

    statement = (
        select(CoreDatasource)
        .where(CoreDatasource.oid == target_oid)
        .order_by(CoreDatasource.name)
    )
    return session.exec(statement).all()
def get_ds(session: SessionDep, id: int):
    """Return the first ``CoreDatasource`` row whose id equals ``id`` (``None`` if no row matches)."""
    return session.exec(select(CoreDatasource).where(CoreDatasource.id == id)).first()
def check_status_by_id(session: SessionDep, trans: Trans, ds_id: int, is_raise: bool = False):
    """Look up a datasource by primary key and check its status.

    Returns the result of ``check_status`` when the datasource exists.
    When it does not exist, returns ``False`` or, if ``is_raise`` is
    truthy, raises ``HTTPException`` (status 500, ``i18n_ds_invalid``).
    """
    ds = session.get(CoreDatasource, ds_id)
    if ds is not None:
        return check_status(session, trans, ds, is_raise)
    if not is_raise:
        return False
    raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid'))
def check_status(session: SessionDep, trans: Trans, ds: CoreDatasource, is_raise: bool = False):
    """Delegate to ``check_connection`` for ``ds`` and return its result.

    ``session`` is part of the signature but is not used by this function.
    """
    connection_result = check_connection(trans, ds, is_raise)
    return connection_result
def check_name(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDatasource):
    """Reject a datasource name that is already taken within the user's oid.

    Args:
        session: Database session used for the lookup.
        trans: Translation callable used to build the error detail.
        user: Current user; supplies the oid the name must be unique in.
        ds: Datasource being created (``ds.id`` is ``None``) or updated
            (``ds.id`` is set, and that row is excluded from the check).

    Raises:
        HTTPException: Status 500 with ``i18n_ds_name_exist`` when another
            datasource with the same name exists under the same oid.
    """
    # Use the same fallback as create_ds / get_datasource_list: rows created
    # by a user without an oid are stored with oid 1, so comparing against a
    # raw ``None`` would never find the duplicate.
    current_oid = user.oid if user.oid is not None else 1

    conditions = [CoreDatasource.name == ds.name, CoreDatasource.oid == current_oid]
    if ds.id is not None:
        # Updating: the datasource must not collide with itself.
        conditions.append(CoreDatasource.id != ds.id)

    # One matching row is enough to prove a conflict; no need to load them all.
    duplicate = session.query(CoreDatasource).filter(and_(*conditions)).first()
    if duplicate is not None:
        raise HTTPException(status_code=500, detail=trans('i18n_ds_name_exist'))
def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: CreateDatasource):
    """Create a datasource from ``create_ds``, persist it, and sync its tables.

    The new datasource is name-checked, stamped with creation metadata,
    stored with status ``"Success"`` (no connection check is performed
    here), and then its selected tables/fields and table counter are synced.
    Returns the populated ``CoreDatasource`` carrying the generated id.
    """
    ds = CoreDatasource()
    deepcopy_ignore_extra(create_ds, ds)
    check_name(session, trans, user, ds)

    ds.create_time = datetime.datetime.now()
    ds.create_by = user.id
    if user.oid is not None:
        ds.oid = user.oid
    else:
        ds.oid = 1
    ds.status = "Success"
    type_names = db_type_relation()
    ds.type_name = type_names[ds.type]

    # Persist a copy and read back the generated primary key before committing.
    persisted = CoreDatasource(**ds.model_dump())
    session.add(persisted)
    session.flush()
    session.refresh(persisted)
    ds.id = persisted.id
    session.commit()

    # Save tables and fields, then refresh the selected/total table counter.
    sync_table(session, ds, create_ds.tables)
    updateNum(session, ds)
    return ds
def chooseTables(session: SessionDep, trans: Trans, id: int, tables: List[CoreTable]):
    """Replace the selected tables of datasource ``id`` with ``tables``.

    Args:
        session: Database session.
        trans: Translation callable used for error details.
        id: Primary key of the datasource.
        tables: Tables that should remain selected for the datasource.

    Raises:
        HTTPException: Status 500 with ``i18n_ds_invalid`` when no datasource
            has the given id (same error ``check_status_by_id`` raises), plus
            whatever ``check_status`` raises when called with ``is_raise=True``.
    """
    ds = session.query(CoreDatasource).filter(CoreDatasource.id == id).first()
    if ds is None:
        # Fail with a clear error instead of passing None down the call chain.
        raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid'))
    check_status(session, trans, ds, True)
    sync_table(session, ds, tables)
    updateNum(session, ds)
def update_ds(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDatasource):
    """Apply the explicitly-set fields of ``ds`` to its stored row.

    Args:
        session: Database session.
        trans: Translation callable used for error details.
        user: Current user, used for the name-uniqueness check.
        ds: Datasource carrying the id of the row to update and the new values.

    Returns:
        The ``ds`` argument, with ``id`` coerced to ``int`` and ``status``
        set to ``"Success"``.

    Raises:
        HTTPException: Status 500 when the name is already taken
            (from ``check_name``) or when no row with ``ds.id`` exists.
    """
    ds.id = int(ds.id)
    check_name(session, trans, user, ds)
    # No connection check is performed here; the status is stored as "Success".
    ds.status = "Success"

    record = session.exec(select(CoreDatasource).where(CoreDatasource.id == ds.id)).first()
    if record is None:
        # Without this guard the setattr loop below fails with AttributeError.
        raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid'))

    # Only copy fields that were explicitly set on ``ds``.
    for field, value in ds.model_dump(exclude_unset=True).items():
        setattr(record, field, value)
    session.add(record)
    session.commit()
    return ds
def delete_ds(session: SessionDep, id: int):
    """Delete datasource ``id`` together with its table and field metadata.

    For datasources of type ``"excel"`` the tables listed under ``sheets``
    in the decrypted configuration are dropped first, using the connection
    from ``get_engine_conn()``.

    Args:
        session: Database session.
        id: Primary key of the datasource to delete.

    Returns:
        A dict with a ``"message"`` confirming the deletion.

    Raises:
        HTTPException: Status 500 when no datasource has the given id.
    """
    term = session.exec(select(CoreDatasource).where(CoreDatasource.id == id)).first()
    if term is None:
        # Previously this crashed with AttributeError on ``term.type``.
        raise HTTPException(status_code=500, detail=f"Datasource with ID {id} not found.")

    if term.type == "excel":
        # Drop all tables for the current datasource.
        engine = get_engine_conn()
        conf = DatasourceConf(**json.loads(aes_decrypt(term.configuration)))
        with engine.connect() as conn:
            for sheet in conf.sheets or []:
                # Double any embedded quote so the name stays one quoted identifier.
                table_name = str(sheet["tableName"]).replace('"', '""')
                conn.execute(text(f'DROP TABLE IF EXISTS "{table_name}"'))
            conn.commit()

    session.delete(term)
    session.commit()
    delete_table_by_ds_id(session, id)
    delete_field_by_ds_id(session, id)
    return {
        "message": f"Datasource with ID {id} deleted successfully."
    }
def getTables(session: SessionDep, id: int):
    """Load datasource ``id`` and return ``get_tables`` for it."""
    statement = select(CoreDatasource).where(CoreDatasource.id == id)
    datasource = session.exec(statement).first()
    return get_tables(datasource)
def getTablesByDs(session: SessionDep, ds: CoreDatasource):
    """Return ``get_tables`` for an already-loaded datasource.

    ``session`` is accepted but not used; no status check is performed.
    """
    return get_tables(ds)
def getFields(session: SessionDep, id: int, table_name: str):
    """Load datasource ``id`` and return ``get_fields`` for ``table_name``."""
    statement = select(CoreDatasource).where(CoreDatasource.id == id)
    datasource = session.exec(statement).first()
    return get_fields(datasource, table_name)
def getFieldsByDs(session: SessionDep, ds: CoreDatasource, table_name: str):
    """Return ``get_fields`` for ``table_name`` on an already-loaded datasource.

    ``session`` is accepted but not used.
    """
    return get_fields(ds, table_name)
def execSql(session: SessionDep, id: int, sql: str):
    """Load datasource ``id`` and run ``sql`` against it via ``exec_sql`` (third argument ``True``)."""
    statement = select(CoreDatasource).where(CoreDatasource.id == id)
    datasource = session.exec(statement).first()
    return exec_sql(datasource, sql, True)
def sync_table(session: SessionDep, ds: CoreDatasource, tables: List[CoreTable]):
    """Make the stored tables of ``ds`` match ``tables`` and sync their fields.

    For each given table: an existing row (matched by datasource id and table
    name) only gets its ``table_comment`` refreshed; otherwise a new checked
    row is inserted. Each item's ``id`` is set to the stored row id, and its
    fields are then synced. Afterwards every table/field row of the
    datasource that is not among the given tables is deleted; when
    ``tables`` is empty, all of them are deleted.
    """
    kept_ids = []
    for item in tables:
        lookup = select(CoreTable).where(
            and_(CoreTable.ds_id == ds.id, CoreTable.table_name == item.table_name))
        existing = session.exec(lookup).first()

        if existing is None:
            # Save a new table.
            created = CoreTable(ds_id=ds.id, checked=True, table_name=item.table_name,
                                table_comment=item.table_comment, custom_comment=item.table_comment)
            session.add(created)
            session.flush()
            session.refresh(created)
            item.id = created.id
            kept_ids.append(created.id)
            session.commit()
        else:
            # Existing table: only table_comment is updated.
            item.id = existing.id
            kept_ids.append(existing.id)
            existing.table_comment = item.table_comment
            session.add(existing)
            session.commit()

        # Sync the fields of this table.
        columns = getFieldsByDs(session, ds, item.table_name)
        sync_fields(session, ds, item, columns)

    # Remove everything that belongs to this datasource but was not kept.
    table_filter = CoreTable.ds_id == ds.id
    field_filter = CoreField.ds_id == ds.id
    if kept_ids:
        table_filter = and_(table_filter, CoreTable.id.not_in(kept_ids))
        field_filter = and_(field_filter, CoreField.table_id.not_in(kept_ids))
    session.query(CoreTable).filter(table_filter).delete(synchronize_session=False)
    session.query(CoreField).filter(field_filter).delete(synchronize_session=False)
    session.commit()
def sync_fields(session: SessionDep, ds: CoreDatasource, table: CoreTable, fields: List[ColumnSchema]):
    """Make the stored fields of ``table`` match ``fields``.

    Existing rows (matched by table id and field name) get their comment,
    index and type refreshed; unknown fields are inserted as checked rows.
    Each item's ``id`` is set to the stored row id. Stored fields of the
    table that are not in ``fields`` are deleted, but only when at least one
    field was processed.
    """
    kept_ids = []
    for position, column in enumerate(fields):
        lookup = select(CoreField).where(
            and_(CoreField.table_id == table.id, CoreField.field_name == column.fieldName))
        existing = session.exec(lookup).first()

        if existing is None:
            created = CoreField(ds_id=ds.id, table_id=table.id, checked=True, field_name=column.fieldName,
                                field_type=column.fieldType, field_comment=column.fieldComment,
                                custom_comment=column.fieldComment, field_index=position)
            session.add(created)
            session.flush()
            session.refresh(created)
            column.id = created.id
            kept_ids.append(created.id)
            session.commit()
        else:
            column.id = existing.id
            kept_ids.append(existing.id)
            existing.field_comment = column.fieldComment
            existing.field_index = position
            existing.field_type = column.fieldType
            session.add(existing)
            session.commit()

    if not kept_ids:
        return

    stale_filter = and_(CoreField.table_id == table.id, CoreField.id.not_in(kept_ids))
    session.query(CoreField).filter(stale_filter).delete(synchronize_session=False)
    session.commit()
def update_table_and_fields(session: SessionDep, data: TableObj):
    """Update ``data.table`` and then each entry of ``data.fields``, in order."""
    update_table(session, data.table)
    for column in data.fields:
        update_field(session, column)
def updateTable(session: SessionDep, table: CoreTable):
    """Camel-case wrapper that forwards to ``update_table``; returns ``None``."""
    update_table(session, table)
def updateField(session: SessionDep, field: CoreField):
    """Camel-case wrapper that forwards to ``update_field``; returns ``None``."""
    update_field(session, field)
def preview(session: SessionDep, current_user: CurrentUser, id: int, data: TableObj):
    """Run a preview query (at most 100 rows) for ``data.table`` on datasource ``id``.

    Only fields with ``checked`` set are selected. When
    ``is_normal_user(current_user)`` is true, the fields are further reduced
    by column permissions, and the ``filter`` of the first row-permission
    mapping (if any, and non-empty) becomes the WHERE clause. The SQL text is
    chosen by ``ds.type``. Returns the ``exec_sql`` result, or an empty
    ``fields``/``data``/``sql`` payload when no field is left to select.
    """
    ds = session.query(CoreDatasource).filter(CoreDatasource.id == id).first()

    if not data.fields:
        return {"fields": [], "data": [], "sql": ''}

    where = ''
    f_list = [f for f in data.fields if f.checked]
    if is_normal_user(current_user):
        # Column permissions, applied on top of the "checked" selection.
        f_list = get_column_permission_fields(session=session, current_user=current_user, table=data.table,
                                              fields=f_list)
        # Row permission tree.
        filter_mapping = get_row_permission_filters(session=session, current_user=current_user, ds=ds, tables=None,
                                                    single_table=data.table)
        if filter_mapping:
            where_str = filter_mapping[0].get('filter')
            if where_str is not None and where_str != '':
                where = ' where ' + where_str

    fields = [f.field_name for f in f_list]
    if not fields:
        return {"fields": [], "data": [], "sql": ''}

    if ds.type != "excel":
        conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration)))
    else:
        conf = get_engine_config()

    # NOTE: the continuation lines of the SQL literals below sit at column 0
    # on purpose; they are part of the string, not of the code layout.
    sql: str = ""
    if ds.type in ("mysql", "doris"):
        sql = f"""SELECT `{"`, `".join(fields)}` FROM `{data.table.table_name}`
{where}
LIMIT 100"""
    elif ds.type == "sqlServer":
        sql = f"""SELECT TOP 100 [{"], [".join(fields)}] FROM [{conf.dbSchema}].[{data.table.table_name}]
{where}
"""
    elif ds.type in ("pg", "excel", "redshift", "dm"):
        sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}"
{where}
LIMIT 100"""
    elif ds.type == "oracle":
        sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}"
{where}
ORDER BY "{fields[0]}"
OFFSET 0 ROWS FETCH NEXT 100 ROWS ONLY"""
    elif ds.type == "ck":
        sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}"
{where}
LIMIT 100"""
    return exec_sql(ds, sql, True)
def fieldEnum(session: SessionDep, id: int):
    """Return the distinct values of the column described by field ``id``.

    Resolves field -> table -> datasource; if any of the three rows is
    missing, an empty list is returned. Otherwise a ``SELECT DISTINCT`` is
    run, with identifiers wrapped in the prefix/suffix of ``DB.get_db`` for
    the datasource type, and the first result column of every row is returned.
    """
    target_field = session.query(CoreField).filter(CoreField.id == id).first()
    if target_field is None:
        return []

    owner_table = session.query(CoreTable).filter(CoreTable.id == target_field.table_id).first()
    if owner_table is None:
        return []

    ds = session.query(CoreDatasource).filter(CoreDatasource.id == owner_table.ds_id).first()
    if ds is None:
        return []

    field, table = target_field, owner_table
    db = DB.get_db(ds.type)
    sql = f"""SELECT DISTINCT {db.prefix}{field.field_name}{db.suffix} FROM {db.prefix}{table.table_name}{db.suffix}"""
    res = exec_sql(ds, sql, True)

    values = []
    for row in res.get('data'):
        values.append(row.get(res.get('fields')[0]))
    return values
def updateNum(session: SessionDep, ds: CoreDatasource):
    """Store the ``"<selected>/<total>"`` table counter on the datasource row.

    The total comes from ``get_tables(ds)``, or for ``"excel"`` datasources
    from the ``sheets`` entry of the decrypted configuration. The selected
    count comes from ``get_tables_by_ds_id``. Explicitly-set fields of ``ds``
    are also copied onto the stored row before it is committed.
    """
    if ds.type != 'excel':
        all_tables = get_tables(ds)
    else:
        all_tables = json.loads(aes_decrypt(ds.configuration)).get('sheets')
    selected_tables = get_tables_by_ds_id(session, ds.id)
    num = f'{len(selected_tables)}/{len(all_tables)}'

    statement = select(CoreDatasource).where(CoreDatasource.id == ds.id)
    record = session.exec(statement).first()
    for attr_name, attr_value in ds.model_dump(exclude_unset=True).items():
        setattr(record, attr_name, attr_value)
    record.num = num

    session.add(record)
    session.commit()
def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource) -> List[TableAndFields]:
    """Build one ``TableAndFields`` per stored table of ``ds``.

    Each entry carries the schema name (``conf.dbSchema``, or
    ``conf.database`` when the schema is ``None`` or empty), the table row,
    and its checked fields after column-permission filtering.
    """
    tables = session.query(CoreTable).filter(CoreTable.ds_id == ds.id).all()

    if ds.type != "excel":
        conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration)))
    else:
        conf = get_engine_config()

    if conf.dbSchema is None or conf.dbSchema == "":
        schema = conf.database
    else:
        schema = conf.dbSchema

    result: List = []
    for table in tables:
        checked_fields = session.query(CoreField).filter(
            and_(CoreField.table_id == table.id, CoreField.checked == True)).all()
        # Apply column permissions to the checked fields.
        permitted_fields = get_column_permission_fields(session=session, current_user=current_user, table=table,
                                                        fields=checked_fields)
        result.append(TableAndFields(schema=schema, table=table, fields=permitted_fields))
    return result
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource) -> str:
    """Render the tables and fields of ``ds`` as a textual schema description.

    The text starts with a DB_ID/Schema header built from the first table's
    schema name. Each table follows as a ``# Table:`` line (prefixed with the
    schema name unless ``ds.type`` is ``"mysql"``), an optional stripped
    custom comment, and a bracketed list of ``(name:type[, comment])``
    entries. Returns an empty string when the datasource has no tables.
    """
    table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
    if not table_objs:
        return ""

    db_name = table_objs[0].schema
    parts = [f"【DB_ID】 {db_name}\n【Schema】\n"]

    for obj in table_objs:
        if ds.type != "mysql":
            parts.append(f"# Table: {db_name}.{obj.table.table_name}")
        else:
            parts.append(f"# Table: {obj.table.table_name}")

        table_comment = (obj.table.custom_comment or '').strip()
        if table_comment:
            parts.append(f", {table_comment}\n[\n")
        else:
            parts.append('\n[\n')

        field_lines = []
        for field in obj.fields:
            field_comment = (field.custom_comment or '').strip()
            if field_comment:
                field_lines.append(f"({field.field_name}:{field.field_type}, {field_comment})")
            else:
                field_lines.append(f"({field.field_name}:{field.field_type})")

        parts.append(",\n".join(field_lines))
        parts.append('\n]\n')

    # TODO: foreign keys
    return "".join(parts)