from dataclasses import dataclass
from urllib.parse import quote_plus
from app.core.ds.datasource import DataSourceBase
from mysql import connector
from mysql.connector import Error
from sqlalchemy import create_engine
import pandas as pd
from configs import logger
from utils import flat_map


def _quote_identifier(name):
    """Return *name* as a backtick-quoted MySQL identifier.

    Embedded backticks are doubled (MySQL's escape inside a quoted
    identifier), so the result can be spliced into SQL safely. Identifiers
    cannot be passed as bound parameters, hence this helper.
    """
    return '`{}`'.format(str(name).replace('`', '``'))


def _as_text(value):
    """Return *value* as ``str``, decoding ``bytes``/``bytearray`` as UTF-8.

    NOTE(review): some mysql-connector versions return DESCRIBE columns
    (e.g. ``Type``) as bytes instead of str; ``':'.join`` would fail on them.
    """
    if isinstance(value, (bytes, bytearray)):
        return value.decode('utf-8')
    return str(value)


@dataclass
class MysqlDS(DataSourceBase):
    """MySQL data source: connection URLs plus simple metadata/preview queries."""

    type = 'mysql'

    @property
    def jdbc_url(self):
        """JDBC URL for this database: ``jdbc:mysql://<host>:<port>/<database>``."""
        return f'jdbc:mysql://{self.host}:{self.port}/{self.database_name}'

    @property
    def jdbc_driver_class(self):
        """Fully-qualified JDBC driver class name."""
        # NOTE(review): 'com.mysql.jdbc.Driver' is the legacy Connector/J 5.x
        # name; Connector/J 8+ uses 'com.mysql.cj.jdbc.Driver'. Left unchanged
        # because the driver jar actually deployed is not visible here.
        return 'com.mysql.jdbc.Driver'

    @property
    def connection_str(self):
        """SQLAlchemy URL for this database using the PyMySQL driver.

        The username and password are percent-encoded so characters such as
        ``@``, ``:`` or ``/`` in credentials cannot corrupt the URL.
        """
        user = quote_plus(str(self.username))
        password = quote_plus(str(self.password))
        return f'mysql+pymysql://{user}:{password}@{self.host}:{self.port}/{self.database_name}'

    def _open_connection(self):
        """Open a new mysql-connector connection with this data source's settings."""
        return connector.connect(host=self.host, port=self.port,
                                 database=self.database_name,
                                 user=self.username, password=self.password)

    def is_connect(self):
        """Check whether a MySQL connection can be established.

        Opens a connection, then closes it again.

        Returns:
            True if the connection was established, False otherwise.
            ``mysql.connector.Error`` is logged rather than raised.
        """
        conn = None
        try:
            conn = self._open_connection()
            if not conn.is_connected():
                return False
            logger.info('Connected to MySQL database')
            return True
        except Error as e:
            logger.error(e)
            return False
        finally:
            if conn is not None and conn.is_connected():
                conn.close()

    def _execute_sql(self, sqls):
        """Run each statement of *sqls* on a single connection and collect rows.

        Args:
            sqls: iterable whose items are either a SQL string, or a
                ``(sql, params)`` pair whose ``params`` are bound by the driver
                through ``%s`` placeholders (no manual quoting needed).

        Returns:
            A list with one ``fetchall()`` result per executed statement (an
            empty list for statements that return no rows). On a driver error
            the error is logged and only the results gathered so far are
            returned, so the list may be shorter than *sqls*.
        """
        conn = None
        res = []
        try:
            conn = self._open_connection()
            cursor = conn.cursor()
            for item in sqls:
                sql, params = (item, None) if isinstance(item, str) else item
                cursor.execute(sql, params)
                # fetchall() raises when the statement produced no result set.
                res.append(cursor.fetchall() if cursor.with_rows else [])
            logger.info(res)
        except Error as e:
            logger.error(e)
        finally:
            if conn is not None and conn.is_connected():
                conn.close()
        return res

    def get_preview_data(self, table_name, limit=100):
        """Return the column names and the first *limit* rows of *table_name*.

        Returns:
            ``{'header': <column names in table order>, 'content': <row list>}``.

        Raises:
            ValueError / TypeError: if *limit* cannot be converted to an int.
            IndexError: if a query failed (``_execute_sql`` logs driver errors
                and returns fewer result sets).
        """
        # ORDER BY is required: INFORMATION_SCHEMA row order is unspecified,
        # and the header must line up with the columns of SELECT *.
        header_sql = ('SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS '
                      'WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s '
                      'ORDER BY ORDINAL_POSITION')
        # LIMIT must be a literal integer, so coerce instead of interpolating raw input.
        content_sql = f'SELECT * FROM {_quote_identifier(table_name)} LIMIT {int(limit)}'
        res = self._execute_sql([(header_sql, (self.database_name, table_name)), content_sql])
        logger.info(res)
        return {
            'header': flat_map(lambda x: x, res[0]),
            'content': res[1]
        }

    def list_tables(self):
        """Return the names of the base tables (views excluded) in this database.

        Raises:
            IndexError: if the query failed (error is logged by ``_execute_sql``).
        """
        sql = ('SELECT table_name FROM information_schema.tables '
               'WHERE table_type = %s AND table_schema = %s')
        res = self._execute_sql([(sql, ('BASE TABLE', self.database_name))])
        return flat_map(lambda x: x, res[0])

    def get_table_schema(self, table_name):
        """Describe *table_name* as ``'<index>:<column name>:<column type>'`` strings.

        Raises:
            Exception: ``'table not found'`` when DESCRIBE produced no result.
                Because ``_execute_sql`` swallows driver errors, this is also
                raised for other failures such as a connection error.
        """
        sql = f'describe {_quote_identifier(self.database_name)}.{_quote_identifier(table_name)}'
        res = self._execute_sql([sql])
        if not res:
            raise Exception('table not found')
        res = [[str(i), *map(_as_text, x)] for i, x in enumerate(res[0])]
        logger.info(res)
        return flat_map(lambda x: [':'.join(x[:3])], res)