"""CRUD and connectivity helpers for ``JobJdbcDatasource`` records."""
import time
from typing import List

from sqlalchemy.orm import Session
from sqlalchemy import func

from app.core.datasource.datasource import DataSrouceFactory
import app.schemas as schemas
import app.models as models
from app.utils import decode_user
from app.utils.utils import decode_base64


def _decode(url, datasource, database_name):
    """Reduce a stored JDBC URL to its ``host:port`` part.

    Removes ``jdbc:``, ``hive2://``, ``<datasource>://`` and
    ``/<database_name>``, then drops everything from the first ``?``.
    The removals are plain substring replacements applied anywhere in the
    string, not only at its start.
    """
    url = (
        url.replace('jdbc:', '')
        .replace('hive2://', '')
        .replace(f'{datasource}://', '')
        .replace(f'/{database_name}', '')
    )
    return url.split('?')[0]


def _format_datasource(db: Session, item: schemas.JobJdbcDatasourceBase, ds_id: int = 0):
    """Build a datasource connector from a request payload or a stored record.

    Args:
        db: Active SQLAlchemy session.
        item: Request schema describing the datasource; ignored when
            ``ds_id`` is non-zero.
        ds_id: When non-zero, the ``JobJdbcDatasource`` row with this id is
            loaded and used instead of ``item``.

    Returns:
        ``(ds, item)`` where ``ds`` comes from ``DataSrouceFactory.create``.
        For a caller-supplied ``item`` its ``jdbc_url``, ``jdbc_username`` and
        ``jdbc_password`` are rewritten from ``ds``; the credentials become
        ``None`` when ``item.kerberos`` is not 0. A row loaded via ``ds_id``
        is returned unmodified.

    Raises:
        Exception: '未找到该数据源' (datasource not found),
            '数据库地址填写错误' (address is not exactly ``host:port``) or
            'jdbc_url无效' (empty host or port).
    """
    loaded_from_db = ds_id != 0
    if loaded_from_db:
        item = db.query(models.JobJdbcDatasource).filter(models.JobJdbcDatasource.id == ds_id).first()
        if not item:
            raise Exception('未找到该数据源')

    # Work on locals rather than assigning to ``item``. A row loaded above is
    # tracked by the session, so writing decoded values onto it would make it
    # dirty (and visible to later reads in the same session).
    address = _decode(item.jdbc_url, item.datasource, item.database_name)
    username, password = decode_user(item.jdbc_username, item.jdbc_password)
    try:
        host, port = address.split(':')
    except ValueError as err:
        # Raised when the address does not contain exactly one ':'.
        raise Exception('数据库地址填写错误') from err
    if not host or not port:
        raise Exception('jdbc_url无效')

    config = {
        'port': port,
        'host': host,
        'username': username,
        'password': password,
        'database_name': item.database_name,
    }
    if item.datasource == 'hive':
        config.update({
            'kerberos': item.kerberos,
            'keytab': item.keytab,
            'krb5config': item.krb5config,
            'kerberos_service_name': item.kerberos_service_name,
            'principal': item.principal,
        })
    else:
        config['use_ssl'] = item.use_ssl
    ds = DataSrouceFactory.create(item.datasource, config)

    if not loaded_from_db:
        item.jdbc_url = ds.jdbc_url
        item.jdbc_username = ds.jdbc_username if item.kerberos == 0 else None
        item.jdbc_password = ds.jdbc_password if item.kerberos == 0 else None
    return ds, item


def test_datasource_connection(db: Session, item: schemas.JobJdbcDatasourceCreate):
    """Return ``ds.is_connect()`` for an unsaved datasource payload.

    A non-empty ``item.jdbc_password`` is passed through ``decode_base64``
    first (presumably the client sends it base64-encoded).
    """
    if item.jdbc_password:
        item.jdbc_password = decode_base64(item.jdbc_password)
    ds, _ = _format_datasource(db, item)
    return ds.is_connect()


def get_table_schema(db: Session, ds_id: int, table_name: str):
    """Return ``ds.get_table_schema(table_name)`` for the stored datasource ``ds_id``."""
    ds, _ = _format_datasource(db, None, ds_id)
    return ds.get_table_schema(table_name)


def get_preview_data(db: Session, ds_id: int, table_name: str, limit: int = 100):
    """Return ``ds.get_preview_data(table_name, limit)`` for datasource ``ds_id``."""
    ds, _ = _format_datasource(db, None, ds_id)
    return ds.get_preview_data(table_name, limit)


def get_table_names(db: Session, ds_id: int):
    """Return ``ds.list_tables()`` for the stored datasource ``ds_id``."""
    ds, _ = _format_datasource(db, None, ds_id)
    return ds.list_tables()


def create_job_jdbc_datasource(db: Session, item: schemas.JobJdbcDatasourceCreate):
    """Validate, connection-test and persist a new datasource.

    Returns:
        The refreshed ``JobJdbcDatasource`` row.

    Raises:
        Exception: '数据源名称重复' if an active (``status == 1``) datasource
            already has this name (compared via ``func.binary``),
            '连接失败,不允许添加' if the connection test fails, or any error
            raised by ``_format_datasource``.
    """
    if item.jdbc_password:
        item.jdbc_password = decode_base64(item.jdbc_password)
    ds, item = _format_datasource(db, item)

    # Do the cheap database check before opening a connection to the target.
    name_item = (
        db.query(models.JobJdbcDatasource)
        .filter(models.JobJdbcDatasource.datasource_name == func.binary(item.datasource_name))
        .filter(models.JobJdbcDatasource.status == 1)
        .first()
    )
    if name_item:
        raise Exception('数据源名称重复')
    if not ds.is_connect():
        raise Exception('连接失败,不允许添加')

    create_time = int(time.time())
    db_item = models.JobJdbcDatasource(
        **item.dict(),
        status=1,
        create_time=create_time,
        create_by='admin',
        update_time=create_time,
        update_by='admin',
        jdbc_driver_class=ds.jdbc_driver_class,
    )
    db.add(db_item)
    db.commit()
    db.refresh(db_item)
    return db_item


def get_job_jdbc_datasources(db: Session, datasource_type: str = None, skip: int = 0, limit: int = 20):
    """List active datasources, newest first, optionally filtered by type.

    Each returned row's ``jdbc_url`` is replaced by its ``host:port`` form.
    """
    # NOTE(review): ``skip`` and ``limit`` are accepted but never applied; all
    # matching rows are returned. Applying them would change results for
    # existing callers, so confirm intent before wiring them in.
    query = db.query(models.JobJdbcDatasource)
    if datasource_type:
        query = query.filter(models.JobJdbcDatasource.datasource == datasource_type)
    res: List[models.JobJdbcDatasource] = (
        query.filter(models.JobJdbcDatasource.status == 1)
        .order_by(models.JobJdbcDatasource.create_time.desc())
        .all()
    )
    # NOTE(review): this mutates session-attached rows for display purposes; a
    # later commit on the same session would persist the shortened URL.
    for row in res:
        row.jdbc_url = _decode(row.jdbc_url, row.datasource, row.database_name)
    return res


def get_job_jdbc_datasources_info(db: Session, ds_id: int):
    """Return datasource ``ds_id`` with a ``host:port`` URL and decoded username.

    Raises:
        Exception: '未找到该数据源' if no row has this id.
    """
    db_item: models.JobJdbcDatasource = (
        db.query(models.JobJdbcDatasource).filter(models.JobJdbcDatasource.id == ds_id).first()
    )
    if not db_item:
        raise Exception('未找到该数据源')
    # NOTE(review): mutates a session-attached row, see get_job_jdbc_datasources.
    db_item.jdbc_url = _decode(db_item.jdbc_url, db_item.datasource, db_item.database_name)
    if db_item.jdbc_username:
        db_item.jdbc_username = decode_base64(db_item.jdbc_username)
    return db_item


def update_job_jdbc_datasources(db: Session, ds_id: int, update_item: schemas.JobJdbcDatasourceUpdate):
    """Validate, connection-test and apply an update to datasource ``ds_id``.

    Only fields explicitly set on ``update_item`` (``exclude_unset=True``) are
    copied onto the row, plus those rewritten by ``_format_datasource``.

    Raises:
        Exception: '未找到该数据源' if the row does not exist, '数据源名称重复'
            if another active datasource uses the name, '连接失败,不允许添加'
            if the connection test fails, or any error raised by
            ``_format_datasource``.
    """
    if update_item.jdbc_password:
        update_item.jdbc_password = decode_base64(update_item.jdbc_password)
    ds, update_item = _format_datasource(db, update_item)

    # Do the cheap database checks before opening a connection to the target.
    db_item = db.query(models.JobJdbcDatasource).filter(models.JobJdbcDatasource.id == ds_id).first()
    if not db_item:
        raise Exception('未找到该数据源')
    name_item = (
        db.query(models.JobJdbcDatasource)
        .filter(models.JobJdbcDatasource.datasource_name == func.binary(update_item.datasource_name))
        .filter(models.JobJdbcDatasource.status == 1)
        .filter(models.JobJdbcDatasource.id != ds_id)
        .first()
    )
    if name_item:
        raise Exception('数据源名称重复')
    # NOTE(review): message reads "adding not allowed" although this is an
    # update; kept unchanged in case callers match on it.
    if not ds.is_connect():
        raise Exception('连接失败,不允许添加')

    for key, value in update_item.dict(exclude_unset=True).items():
        setattr(db_item, key, value)
    db_item.jdbc_driver_class = ds.jdbc_driver_class
    db_item.update_time = int(time.time())
    db_item.update_by = 'admin1'  # TODO: record the acting user instead of a hard-coded name
    db.commit()
    db.refresh(db_item)
    return db_item


def delete_job_jdbc_datasource(db: Session, ds_id: int):
    """Soft-delete datasource ``ds_id`` by setting ``status`` to 0.

    Raises:
        Exception: '未找到该数据源' if no row has this id.
    """
    db_item = db.query(models.JobJdbcDatasource).filter(models.JobJdbcDatasource.id == ds_id).first()
    if not db_item:
        raise Exception('未找到该数据源')
    db_item.status = 0
    db.commit()
    db.refresh(db_item)
    return db_item


def get_job_jdbc_datasource(db: Session, ds_id: int):
    """Return the raw ``JobJdbcDatasource`` row ``ds_id``.

    Raises:
        Exception: '未找到该数据源' if no row has this id.
    """
    db_item: models.JobJdbcDatasource = (
        db.query(models.JobJdbcDatasource).filter(models.JobJdbcDatasource.id == ds_id).first()
    )
    if not db_item:
        raise Exception('未找到该数据源')
    return db_item


def get_job_jdbc_datasource_table_location(db: Session, db_name: str, table_name: str, ds_id: int):
    """Return ``ds.get_table_info(table_name, db_name)`` for datasource ``ds_id``."""
    ds, _ = _format_datasource(db, None, ds_id)
    return ds.get_table_info(table_name, db_name)