import datetime
import json
from typing import List, Optional

from fastapi import HTTPException
from sqlalchemy import and_, text
from sqlmodel import select

from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
from apps.datasource.utils.utils import aes_decrypt
from apps.db.constant import DB
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
from apps.db.engine import get_engine_config, get_engine_conn
from apps.db.type import db_type_relation
from common.core.deps import SessionDep, CurrentUser, Trans
from common.utils.utils import deepcopy_ignore_extra

from .table import get_tables_by_ds_id
from ..crud.field import delete_field_by_ds_id, update_field
from ..crud.table import delete_table_by_ds_id, update_table
from ..models.datasource import CoreDatasource, CreateDatasource, CoreTable, CoreField, ColumnSchema, TableObj, \
    DatasourceConf, TableAndFields


def get_datasource_list(session: SessionDep, user: CurrentUser, oid: Optional[int] = None) -> List[CoreDatasource]:
    current_oid = user.oid if user.oid is not None else 1
    if user.isAdmin and oid:
        current_oid = oid
    return session.exec(
        select(CoreDatasource).where(CoreDatasource.oid == current_oid).order_by(CoreDatasource.name)).all()


def get_ds(session: SessionDep, id: int):
    statement = select(CoreDatasource).where(CoreDatasource.id == id)
    datasource = session.exec(statement).first()
    return datasource


def check_status_by_id(session: SessionDep, trans: Trans, ds_id: int, is_raise: bool = False):
    ds = session.get(CoreDatasource, ds_id)
    if ds is None:
        if is_raise:
            raise HTTPException(status_code=500, detail=trans('i18n_ds_invalid'))
        return False
    return check_status(session, trans, ds, is_raise)


def check_status(session: SessionDep, trans: Trans, ds: CoreDatasource, is_raise: bool = False):
    return check_connection(trans, ds, is_raise)


def check_name(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDatasource):
    """Raise if another datasource in the same organization already uses this name."""
    if ds.id is not None:
        ds_list = session.query(CoreDatasource).filter(
            and_(CoreDatasource.name == ds.name, CoreDatasource.id != ds.id, CoreDatasource.oid == user.oid)).all()
        if ds_list is not None and len(ds_list) > 0:
            raise HTTPException(status_code=500, detail=trans('i18n_ds_name_exist'))
    else:
        ds_list = session.query(CoreDatasource).filter(
            and_(CoreDatasource.name == ds.name, CoreDatasource.oid == user.oid)).all()
        if ds_list is not None and len(ds_list) > 0:
            raise HTTPException(status_code=500, detail=trans('i18n_ds_name_exist'))


def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: CreateDatasource):
    """Persist a new datasource record, then sync its selected tables and fields."""
    ds = CoreDatasource()
    deepcopy_ignore_extra(create_ds, ds)
    check_name(session, trans, user, ds)
    ds.create_time = datetime.datetime.now()
    # status = check_status(session, ds)
    ds.create_by = user.id
    ds.oid = user.oid if user.oid is not None else 1
    ds.status = "Success"
    ds.type_name = db_type_relation()[ds.type]
    record = CoreDatasource(**ds.model_dump())
    session.add(record)
    session.flush()
    session.refresh(record)
    ds.id = record.id
    session.commit()

    # save tables and fields
    sync_table(session, ds, create_ds.tables)
    updateNum(session, ds)
    return ds


def chooseTables(session: SessionDep, trans: Trans, id: int, tables: List[CoreTable]):
    ds = session.query(CoreDatasource).filter(CoreDatasource.id == id).first()
    check_status(session, trans, ds, True)
    sync_table(session, ds, tables)
    updateNum(session, ds)


def update_ds(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDatasource):
    ds.id = int(ds.id)
    check_name(session, trans, user, ds)
    # status = check_status(session, trans, ds)
    ds.status = "Success"
    record = session.exec(select(CoreDatasource).where(CoreDatasource.id == ds.id)).first()
    update_data = ds.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(record, field, value)
    session.add(record)
    session.commit()
    return ds


def delete_ds(session: SessionDep, id: int):
    term = session.exec(select(CoreDatasource).where(CoreDatasource.id == id)).first()
    if term.type == "excel":
        # drop all tables for current datasource
        engine = get_engine_conn()
        conf = DatasourceConf(**json.loads(aes_decrypt(term.configuration)))
        with engine.connect() as conn:
            for sheet in conf.sheets:
                conn.execute(text(f'DROP TABLE IF EXISTS "{sheet["tableName"]}"'))
            conn.commit()

    session.delete(term)
    session.commit()
    delete_table_by_ds_id(session, id)
    delete_field_by_ds_id(session, id)
    return {
        "message": f"Datasource with ID {id} deleted successfully."
    }


def getTables(session: SessionDep, id: int):
    ds = session.exec(select(CoreDatasource).where(CoreDatasource.id == id)).first()
    tables = get_tables(ds)
    return tables


def getTablesByDs(session: SessionDep, ds: CoreDatasource):
    # check_status(session, ds, True)
    tables = get_tables(ds)
    return tables


def getFields(session: SessionDep, id: int, table_name: str):
    ds = session.exec(select(CoreDatasource).where(CoreDatasource.id == id)).first()
    fields = get_fields(ds, table_name)
    return fields


def getFieldsByDs(session: SessionDep, ds: CoreDatasource, table_name: str):
    fields = get_fields(ds, table_name)
    return fields


def execSql(session: SessionDep, id: int, sql: str):
    ds = session.exec(select(CoreDatasource).where(CoreDatasource.id == id)).first()
    return exec_sql(ds, sql, True)


def sync_table(session: SessionDep, ds: CoreDatasource, tables: List[CoreTable]):
    """Upsert the selected tables of a datasource and prune everything else.

    Existing CoreTable rows keep their ids and only get a refreshed comment,
    new tables are inserted, fields are synced per table, and any table (plus
    its fields) missing from the incoming list is deleted.
    """
    id_list = []
    for item in tables:
        statement = select(CoreTable).where(and_(CoreTable.ds_id == ds.id, CoreTable.table_name == item.table_name))
        record = session.exec(statement).first()
        # update existing table, only refresh table_comment
        if record is not None:
            item.id = record.id
            id_list.append(record.id)

            record.table_comment = item.table_comment
            session.add(record)
            session.commit()
        else:
            # save new table
            table = CoreTable(ds_id=ds.id, checked=True, table_name=item.table_name, table_comment=item.table_comment,
                              custom_comment=item.table_comment)
            session.add(table)
            session.flush()
            session.refresh(table)
            item.id = table.id
            id_list.append(table.id)
            session.commit()

        # sync fields
        fields = getFieldsByDs(session, ds, item.table_name)
        sync_fields(session, ds, item, fields)

    if len(id_list) > 0:
        session.query(CoreTable).filter(and_(CoreTable.ds_id == ds.id, CoreTable.id.not_in(id_list))).delete(
            synchronize_session=False)
        session.query(CoreField).filter(and_(CoreField.ds_id == ds.id, CoreField.table_id.not_in(id_list))).delete(
            synchronize_session=False)
        session.commit()
    else:  # delete all tables and fields in this ds
        session.query(CoreTable).filter(CoreTable.ds_id == ds.id).delete(synchronize_session=False)
        session.query(CoreField).filter(CoreField.ds_id == ds.id).delete(synchronize_session=False)
        session.commit()


def sync_fields(session: SessionDep, ds: CoreDatasource, table: CoreTable, fields: List[ColumnSchema]):
    """Upsert the fields of one table and prune fields that no longer exist in the source."""
    id_list = []
    for index, item in enumerate(fields):
        statement = select(CoreField).where(
            and_(CoreField.table_id == table.id, CoreField.field_name == item.fieldName))
        record = session.exec(statement).first()
        if record is not None:
            item.id = record.id
            id_list.append(record.id)

            record.field_comment = item.fieldComment
            record.field_index = index
            record.field_type = item.fieldType
            session.add(record)
            session.commit()
        else:
            field = CoreField(ds_id=ds.id, table_id=table.id, checked=True, field_name=item.fieldName,
                              field_type=item.fieldType, field_comment=item.fieldComment,
                              custom_comment=item.fieldComment, field_index=index)
            session.add(field)
            session.flush()
            session.refresh(field)
            item.id = field.id
            id_list.append(field.id)
            session.commit()

    if len(id_list) > 0:
        session.query(CoreField).filter(and_(CoreField.table_id == table.id, CoreField.id.not_in(id_list))).delete(
            synchronize_session=False)
        session.commit()


def update_table_and_fields(session: SessionDep, data: TableObj):
    update_table(session, data.table)
    for field in data.fields:
        update_field(session, field)


def updateTable(session: SessionDep, table: CoreTable):
    update_table(session, table)


def updateField(session: SessionDep, field: CoreField):
    update_field(session, field)


def preview(session: SessionDep, current_user: CurrentUser, id: int, data: TableObj):
    """Build and run a 100-row preview query for the chosen table and fields.

    For normal users, column and row permissions are applied before the SQL is
    generated; the SELECT syntax is adapted to the datasource type.
    """
    ds = session.query(CoreDatasource).filter(CoreDatasource.id == id).first()
    # check_status(session, ds, True)

    if data.fields is None or len(data.fields) == 0:
        return {"fields": [], "data": [], "sql": ''}

    where = ''
    f_list = [f for f in data.fields if f.checked]
    if is_normal_user(current_user):
        # keep only checked columns the user has column permission for
        f_list = get_column_permission_fields(session=session, current_user=current_user, table=data.table,
                                              fields=f_list)

        # row permission tree
        where_str = ''
        filter_mapping = get_row_permission_filters(session=session, current_user=current_user, ds=ds, tables=None,
                                                    single_table=data.table)
        if filter_mapping:
            mapping_dict = filter_mapping[0]
            where_str = mapping_dict.get('filter')
        where = (' where ' + where_str) if where_str is not None and where_str != '' else ''

    fields = [f.field_name for f in f_list]
    if fields is None or len(fields) == 0:
        return {"fields": [], "data": [], "sql": ''}

    conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
    sql: str = ""
    if ds.type == "mysql" or ds.type == "doris":
        sql = f"""SELECT `{"`, `".join(fields)}` FROM `{data.table.table_name}`
        {where}
        LIMIT 100"""
    elif ds.type == "sqlServer":
        sql = f"""SELECT TOP 100 [{"], [".join(fields)}] FROM [{conf.dbSchema}].[{data.table.table_name}]
        {where}
        """
    elif ds.type == "pg" or ds.type == "excel" or ds.type == "redshift":
        sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}"
        {where}
        LIMIT 100"""
    elif ds.type == "oracle":
        sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}"
        {where}
        ORDER BY "{fields[0]}"
        OFFSET 0 ROWS FETCH NEXT 100 ROWS ONLY"""
    elif ds.type == "ck":
        sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}"
        {where}
        LIMIT 100"""
    elif ds.type == "dm":
        sql = f"""SELECT "{'", "'.join(fields)}" FROM "{conf.dbSchema}"."{data.table.table_name}"
        {where}
        LIMIT 100"""
    return exec_sql(ds, sql, True)
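
# Illustrative example (hypothetical table/field names): for a MySQL or Doris
# datasource with table "orders" and checked fields "id" and "amount", preview()
# builds roughly:
#     SELECT `id`, `amount` FROM `orders`
#      where <row-permission filter, if any>
#     LIMIT 100
# and returns whatever exec_sql() yields for that statement.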


def fieldEnum(session: SessionDep, id: int):
    """Return the distinct values of a single field, for enum-style filters."""
    field = session.query(CoreField).filter(CoreField.id == id).first()
    if field is None:
        return []
    table = session.query(CoreTable).filter(CoreTable.id == field.table_id).first()
    if table is None:
        return []
    ds = session.query(CoreDatasource).filter(CoreDatasource.id == table.ds_id).first()
    if ds is None:
        return []

    db = DB.get_db(ds.type)
    sql = f"""SELECT DISTINCT {db.prefix}{field.field_name}{db.suffix} FROM {db.prefix}{table.table_name}{db.suffix}"""
    res = exec_sql(ds, sql, True)
    return [item.get(res.get('fields')[0]) for item in res.get('data')]
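
# Illustrative example (hypothetical names): for a field "status" on table "orders",
# fieldEnum() issues a statement of the form
#     SELECT DISTINCT <prefix>status<suffix> FROM <prefix>orders<suffix>
# where prefix/suffix are the identifier quotes defined for the datasource type in
# DB.get_db(), and the distinct values are read back from the exec_sql() result.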


def updateNum(session: SessionDep, ds: CoreDatasource):
    """Refresh the 'selected/total' table counter stored on the datasource record."""
    all_tables = get_tables(ds) if ds.type != 'excel' else json.loads(aes_decrypt(ds.configuration)).get('sheets')
    selected_tables = get_tables_by_ds_id(session, ds.id)
    num = f'{len(selected_tables)}/{len(all_tables)}'

    record = session.exec(select(CoreDatasource).where(CoreDatasource.id == ds.id)).first()
    update_data = ds.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(record, field, value)
    record.num = num
    session.add(record)
    session.commit()
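
# For example, a datasource with 3 selected tables out of 12 available ends up
# with record.num == '3/12' (illustrative numbers).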


def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource) -> List[TableAndFields]:
    _list: List = []
    tables = session.query(CoreTable).filter(CoreTable.ds_id == ds.id).all()
    conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
    schema = conf.dbSchema if conf.dbSchema is not None and conf.dbSchema != "" else conf.database
    for table in tables:
        fields = session.query(CoreField).filter(and_(CoreField.table_id == table.id, CoreField.checked == True)).all()

        # apply column permissions to filter the fields
        fields = get_column_permission_fields(session=session, current_user=current_user, table=table, fields=fields)
        _list.append(TableAndFields(schema=schema, table=table, fields=fields))
    return _list


def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource) -> str:
    """Render the permitted tables and fields of a datasource as a text schema block."""
    schema_str = ""
    table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
    if len(table_objs) == 0:
        return schema_str
    db_name = table_objs[0].schema
    schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
    for obj in table_objs:
        schema_str += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" else f"# Table: {obj.table.table_name}"
        table_comment = ''
        if obj.table.custom_comment:
            table_comment = obj.table.custom_comment.strip()
        if table_comment == '':
            schema_str += '\n[\n'
        else:
            schema_str += f", {table_comment}\n[\n"

        field_list = []
        for field in obj.fields:
            field_comment = ''
            if field.custom_comment:
                field_comment = field.custom_comment.strip()
            if field_comment == '':
                field_list.append(f"({field.field_name}:{field.field_type})")
            else:
                field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
        schema_str += ",\n".join(field_list)
        schema_str += '\n]\n'
    # TODO: foreign keys
    return schema_str
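
# Illustrative output (hypothetical names): for a non-MySQL datasource the generated
# block looks roughly like
#     【DB_ID】 sales_db
#     【Schema】
#     # Table: sales_db.orders, order fact table
#     [
#     (id:bigint),
#     (amount:decimal, order amount)
#     ]
# For MySQL the table line omits the "sales_db." prefix, as the code above shows.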