12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- from dataclasses import dataclass
- from app.core.datasource import DataSourceBase
- from mysql import connector
- from mysql.connector import Error
- from sqlalchemy import create_engine
- import pandas as pd
- from configs import logger
- from utils import flat_map
@dataclass
class MysqlDS(DataSourceBase):
    """MySQL data source.

    Builds JDBC / SQLAlchemy connection strings from the inherited
    connection settings (``host``, ``port``, ``database_name``,
    ``username``, ``password``) and runs a few metadata and preview
    queries through ``mysql.connector``.
    """

    # Class-level tag (no annotation, so it is not a dataclass field).
    type = 'mysql'

    @property
    def jdbc_url(self):
        """JDBC URL pointing at ``database_name`` on ``host:port``."""
        return f'jdbc:mysql://{self.host}:{self.port}/{self.database_name}'

    @property
    def jdbc_driver_class(self):
        """Fully-qualified JDBC driver class name.

        NOTE(review): ``com.mysql.jdbc.Driver`` is the legacy Connector/J 5.x
        class; Connector/J 8+ uses ``com.mysql.cj.jdbc.Driver``. Left as-is
        because the driver jar in use is not visible here -- confirm.
        """
        return 'com.mysql.jdbc.Driver'

    @property
    def connection_str(self):
        """SQLAlchemy URL for the PyMySQL driver.

        Username and password are percent-encoded so characters such as
        ``@``, ``:``, ``/`` or ``%`` in credentials cannot corrupt the URL.
        """
        from urllib.parse import quote

        # quote(..., safe='') rather than quote_plus: SQLAlchemy decodes the
        # userinfo with plain unquote, which would not turn '+' back into ' '.
        user = quote(str(self.username), safe='')
        password = quote(str(self.password), safe='')
        return (f'mysql+pymysql://{user}:{password}'
                f'@{self.host}:{self.port}/{self.database_name}')

    def _open_connection(self):
        """Open a new ``mysql.connector`` connection to this data source.

        Errors raised by ``connector.connect`` propagate to the caller.
        """
        return connector.connect(host=self.host,
                                 port=self.port,
                                 database=self.database_name,
                                 user=self.username,
                                 password=self.password)

    @staticmethod
    def _quote_identifier(name):
        """Backtick-quote a MySQL identifier, doubling embedded backticks."""
        return '`' + str(name).replace('`', '``') + '`'

    def is_connect(self):
        """Check whether the MySQL database is reachable.

        Opens a probe connection and closes it again.

        Returns:
            True if the connection was established, False otherwise.
            A ``mysql.connector`` ``Error`` is logged and reported as False;
            any other exception propagates instead of being swallowed.
        """
        try:
            conn = self._open_connection()
        except Error as e:
            logger.error(e)
            return False
        try:
            if not conn.is_connected():
                return False
            logger.info('Connected to MySQL database')
            return True
        finally:
            if conn.is_connected():
                conn.close()

    def _execute_sql(self, sqls):
        """Run each statement on one connection and collect the fetched rows.

        Args:
            sqls: iterable whose items are either a SQL string or a
                ``(sql, params)`` pair; ``params`` are bound by the driver
                via ``%s`` placeholders.

        Returns:
            A list with one ``fetchall()`` result per statement, in order.
            On a ``mysql.connector`` ``Error`` the error is logged and the
            results gathered so far are returned, so the list may be shorter
            than ``sqls`` (possibly empty).
        """
        res = []
        conn = None
        try:
            conn = self._open_connection()
            cursor = conn.cursor()
            try:
                for item in sqls:
                    sql, params = (item, None) if isinstance(item, str) else item
                    cursor.execute(sql, params)
                    res.append(cursor.fetchall())
            finally:
                cursor.close()
        except Error as e:
            logger.error(e)
        finally:
            if conn is not None and conn.is_connected():
                conn.close()
        # Log a summary only: full result sets can be large and may hold
        # sensitive data.
        logger.debug(f'_execute_sql fetched {len(res)} result set(s)')
        return res

    def get_preview_data(self, table_name, limit=100):
        """Return up to ``limit`` rows of ``table_name`` plus its column names.

        Args:
            table_name: bare table name inside ``database_name``.
            limit: maximum number of rows; must be convertible to ``int``.

        Returns:
            ``{'header': ..., 'content': ...}`` where ``header`` is the
            ``flat_map``-flattened column names in table order and
            ``content`` is the list of fetched rows. Parts whose query failed
            (the error is logged) are empty.

        Raises:
            ValueError / TypeError: if ``limit`` is not an integer value.
        """
        # Only a real integer may be spliced into LIMIT.
        limit = int(limit)
        columns_sql = ('SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS '
                       'WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s '
                       'ORDER BY ORDINAL_POSITION')
        # Identifiers cannot be bound as parameters, so quote them instead.
        rows_sql = f'SELECT * FROM {self._quote_identifier(table_name)} LIMIT {limit}'
        res = self._execute_sql([
            (columns_sql, (self.database_name, table_name)),
            rows_sql,
        ])
        header_rows = res[0] if len(res) > 0 else []
        content = res[1] if len(res) > 1 else []
        return {
            'header': flat_map(lambda x: x, header_rows),
            'content': content,
        }

    def list_tables(self):
        """Return the names of the base tables (not views) in ``database_name``.

        The result is ``flat_map``-flattened; it is empty if the query fails
        (the error is logged).
        """
        sql = ('SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES '
               "WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_SCHEMA = %s")
        res = self._execute_sql([(sql, (self.database_name,))])
        return flat_map(lambda x: x, res[0] if res else [])
|