from dataclasses import dataclass
from urllib.parse import quote_plus

from app.core.datasource.datasource import DataSourceBase
from mysql import connector
from mysql.connector import Error
from sqlalchemy import create_engine
import pandas as pd
from configs import logger
from utils import flat_map


class MysqlDS(DataSourceBase):
    """MySQL data source: connectivity probe, table listing, schema and preview.

    Connection settings are forwarded to ``DataSourceBase.__init__``; the
    methods below read them back as ``self.host``, ``self.port``,
    ``self.username``, ``self.password`` and ``self.database_name``
    (assumes the base class stores them under those names -- it is not
    visible in this file).
    """

    def __init__(self, host: str, port: int, username: str, password: str,
                 database_name: str, use_ssl: int, type: str = 'mysql'):
        """Create the data source.

        Args:
            host: MySQL server host name or address.
            port: MySQL server port.
            username: Login user.
            password: Login password.
            database_name: Default database (schema) to connect to.
            use_ssl: ``0`` disables SSL; any other value leaves it enabled.
            type: Data source type tag forwarded to the base class.
        """
        super().__init__(type, host, port, username, password, database_name)
        self.use_ssl = use_ssl

    @property
    def jdbc_url(self):
        """Return the JDBC URL, with ``useSSL=false`` appended when SSL is off."""
        jdbc = f'jdbc:mysql://{self.host}:{self.port}/{self.database_name}'
        if self.use_ssl == 0:
            jdbc = f'{jdbc}?useSSL=false'
        return jdbc

    @property
    def jdbc_driver_class(self):
        """Return the JDBC driver class name."""
        return 'com.mysql.jdbc.Driver'

    @property
    def connection_str(self):
        """Return a ``mysql+pymysql://`` connection URL.

        The user name and password are percent-encoded so that characters
        such as ``@``, ``:`` or ``/`` in the credentials cannot corrupt the
        URL structure. Credentials without special characters are emitted
        unchanged.
        """
        user = quote_plus(str(self.username))
        secret = quote_plus(str(self.password))
        return (
            f'mysql+pymysql://{user}:{secret}'
            f'@{self.host}:{self.port}/{self.database_name}'
        )

    @staticmethod
    def _quote_identifier(name):
        """Backtick-quote a bare SQL identifier, doubling embedded backticks."""
        escaped = str(name).replace('`', '``')
        return f'`{escaped}`'

    def _connect(self):
        """Open a new connection using this data source's settings."""
        return connector.connect(
            host=self.host,
            port=self.port,
            database=self.database_name,
            user=self.username,
            password=self.password,
            # use_ssl == 0 means "SSL off"; every other value keeps SSL on.
            ssl_disabled=self.use_ssl == 0,
            connection_timeout=5,
        )

    def is_connect(self):
        """Check whether a MySQL connection can be established.

        Returns:
            ``True`` if a connection was opened and reports itself as
            connected, ``False`` otherwise. Failures are logged, never raised.
        """
        conn = None
        try:
            conn = self._connect()
            connected = conn.is_connected()
            if connected:
                logger.info('Connected to MySQL database')
            return connected
        except Exception as e:
            # This is a best-effort probe: callers have never had to handle
            # an exception from it, so report every failure as False.
            # BaseException (KeyboardInterrupt, SystemExit) is deliberately
            # not caught and propagates.
            logger.error(e)
            return False
        finally:
            # Cleanup only -- no return here, so nothing in flight is masked.
            if conn is not None and conn.is_connected():
                conn.close()

    def _execute_sql(self, sqls, params=None):
        """Run statements on a fresh connection and collect their rows.

        Args:
            sqls: Iterable of SQL statements, executed in order.
            params: Optional sequence aligned with ``sqls``; each entry is the
                parameter tuple bound to the ``%s`` placeholders of the
                matching statement, or ``None`` for a statement without
                parameters. Omit it to run every statement as-is.

        Returns:
            A list with one ``fetchall()`` result per statement.

        Raises:
            Exception: If connecting or executing raises a connector
                ``Error``; the original error is attached as ``__cause__``.
        """
        statements = list(sqls)
        bound = list(params) if params is not None else [None] * len(statements)
        conn = None
        res = []
        try:
            conn = self._connect()
            cursor = conn.cursor()
            try:
                for sql, args in zip(statements, bound):
                    cursor.execute(sql, args)
                    res.append(cursor.fetchall())
            finally:
                cursor.close()
            logger.info(res)
        except Error as e:
            logger.error(e)
            raise Exception('mysql 连接失败') from e
        finally:
            if conn is not None and conn.is_connected():
                conn.close()
        return res

    def get_preview_data(self, table_name, limit=100):
        """Return column names and up to ``limit`` rows of a table.

        Args:
            table_name: Bare (unquoted) table name in the default database.
            limit: Maximum number of rows to fetch.

        Returns:
            A dict with ``'header'`` (column names taken from the table
            schema) and ``'content'`` (the fetched rows).
        """
        table_schema = self.get_table_schema(table_name)
        # The identifier is quoted and the limit forced to an integer so that
        # neither value can change the shape of the statement.
        query = (
            f"SELECT * FROM {self._quote_identifier(table_name)} "
            f"LIMIT {int(limit)}"
        )
        res = self._execute_sql([query])
        logger.info(res)
        return {
            # Schema entries look like "<index>:<name>:<type>".
            'header': [str(column).split(':')[1] for column in table_schema],
            'content': res[0],
        }

    def list_tables(self):
        """Return the names of all tables in the configured database."""
        sql = (
            'SELECT table_name FROM information_schema.tables '
            'WHERE table_schema = %s'
        )
        res = self._execute_sql([sql], params=[(self.database_name,)])
        return flat_map(lambda x: x, res[0])

    def get_table_schema(self, table_name):
        """Describe a table as ``"<index>:<field>:<type>"`` strings.

        Args:
            table_name: Bare (unquoted) table name in the configured database.

        Returns:
            The result of ``flat_map`` over one such string per column, in
            ``DESCRIBE`` order.

        Raises:
            Exception: ``'table not found'`` if no result set came back.
        """
        def handle_col(x):
            # The driver may hand back bytes/bytearray for text columns.
            line = [
                s.decode('utf-8') if isinstance(s, (bytes, bytearray)) else str(s)
                for s in x
            ]
            return [':'.join(line[:3])]

        sql = (
            f'describe {self._quote_identifier(self.database_name)}'
            f'.{self._quote_identifier(table_name)}'
        )
        logger.info(sql)
        res = self._execute_sql([sql])
        if not res:
            raise Exception('table not found')
        # Prefix each DESCRIBE row with its ordinal position.
        rows = [[str(i), *x] for i, x in enumerate(res[0])]
        logger.info(rows)
        return flat_map(lambda x: handle_col(x), rows)