123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- from dataclasses import dataclass
- from app.core.datasource.datasource import DataSourceBase
- from mysql import connector
- from mysql.connector import Error
- from sqlalchemy import create_engine
- import pandas as pd
- from configs import logger
- from utils import flat_map
class MysqlDS(DataSourceBase):
    """MySQL data source.

    Every query opens a fresh ``mysql.connector`` connection and closes it
    afterwards. JDBC and SQLAlchemy connection strings are also exposed for
    other consumers of the data source.
    """

    # Class-level fallback so instances built without __init__ still work.
    connect_timeout = 5

    def __init__(self, host: str, port: int, username: str, password: str, database_name: str, use_ssl: int,
                 type: str = 'mysql', connect_timeout: int = 5):
        """Store connection settings.

        Args:
            host, port, username, password, database_name: forwarded to
                DataSourceBase together with ``type``.
            use_ssl: 0 disables SSL; any other value leaves SSL enabled.
            type: data-source type tag passed to DataSourceBase (the name
                shadows the builtin but is part of the public signature).
            connect_timeout: seconds passed to mysql.connector as
                ``connection_timeout``.
        """
        super().__init__(type, host, port, username, password, database_name)
        self.use_ssl = use_ssl
        self.connect_timeout = connect_timeout

    @property
    def jdbc_url(self):
        """JDBC URL for this database; appends ``useSSL=false`` when use_ssl == 0."""
        jdbc = f'jdbc:mysql://{self.host}:{self.port}/{self.database_name}'
        if self.use_ssl == 0:
            jdbc = f'{jdbc}?useSSL=false'
        return jdbc

    @property
    def jdbc_driver_class(self):
        """JDBC driver class name."""
        # NOTE(review): Connector/J 8+ names this 'com.mysql.cj.jdbc.Driver'
        # (this legacy name is deprecated); kept to match the deployed jar.
        return 'com.mysql.jdbc.Driver'

    @property
    def connection_str(self):
        """SQLAlchemy URL for the PyMySQL driver.

        The password is percent-encoded so characters such as ``@``, ``:``,
        ``/`` or ``%`` cannot corrupt the URL; SQLAlchemy decodes it again
        when parsing.
        """
        from urllib.parse import quote
        password = quote(str(self.password), safe='')
        return f'mysql+pymysql://{self.username}:{password}@{self.host}:{self.port}/{self.database_name}'

    def _connect(self):
        """Open a new mysql.connector connection; the caller must close it."""
        return connector.connect(host=self.host,
                                 port=self.port,
                                 database=self.database_name,
                                 user=self.username,
                                 password=self.password,
                                 # use_ssl == 0 is the only value that disables SSL.
                                 ssl_disabled=self.use_ssl == 0,
                                 connection_timeout=self.connect_timeout)

    @staticmethod
    def _quote_identifier(name):
        """Backtick-quote a MySQL identifier, doubling any embedded backticks.

        NOTE(review): assumes callers pass bare (unqualified, unquoted) names,
        as returned by list_tables — confirm no caller passes "db.table".
        """
        return '`' + str(name).replace('`', '``') + '`'

    @staticmethod
    def _cell_to_str(value):
        """Render a result cell as text, decoding bytes/bytearray as UTF-8."""
        if isinstance(value, (bytes, bytearray)):
            return value.decode('utf-8')
        return str(value)

    def is_connect(self):
        """Check whether the MySQL database is reachable.

        Returns:
            True if a connection could be opened, False otherwise. Ordinary
            exceptions are logged and reported as False. KeyboardInterrupt and
            SystemExit still propagate; returning from ``finally`` used to
            swallow them.
        """
        conn = None
        try:
            conn = self._connect()
            connected = conn.is_connected()
            if connected:
                logger.info('Connected to MySQL database')
            return connected
        except Exception as e:
            logger.error(e)
            return False
        finally:
            if conn is not None and conn.is_connected():
                conn.close()

    def _execute_sql(self, sqls):
        """Run each statement on one fresh connection and fetch all its rows.

        Args:
            sqls: iterable whose items are either an SQL string or a
                ``(sql, params)`` tuple; params are bound by the driver.

        Returns:
            list with one ``cursor.fetchall()`` result per statement, in order.

        Raises:
            Exception('mysql 连接失败'), chained from the driver ``Error``, when
            connecting or executing fails.
        """
        res = []
        conn = None
        try:
            conn = self._connect()
            cursor = conn.cursor()
            try:
                for sql in sqls:
                    if isinstance(sql, tuple):
                        statement, params = sql
                        cursor.execute(statement, params)
                    else:
                        cursor.execute(sql)
                    res.append(cursor.fetchall())
            finally:
                cursor.close()
            # Full result sets can be large or sensitive: keep them out of INFO.
            logger.debug(res)
        except Error as e:
            logger.error(e)
            raise Exception('mysql 连接失败') from e
        finally:
            if conn is not None and conn.is_connected():
                conn.close()
        return res

    def get_preview_data(self, table_name, limit=100):
        """Return up to ``limit`` rows of ``table_name`` plus its column names.

        Returns:
            dict with 'header' (column names in DESCRIBE order) and 'content'
            (rows as returned by ``cursor.fetchall()``).
        """
        table_schema = self.get_table_schema(table_name)
        # int() rejects non-numeric limits instead of splicing them into SQL.
        sql = f'SELECT * FROM {self._quote_identifier(table_name)} LIMIT {int(limit)}'
        res = self._execute_sql([sql])
        logger.debug(res)
        return {
            # get_table_schema entries look like "<index>:<Field>:<Type>".
            'header': [str(column).split(':')[1] for column in table_schema],
            'content': res[0],
        }

    def list_tables(self):
        """Return the names of every entry in information_schema.tables for this database.

        No table_type filter is applied, so views are included as well.
        """
        # A bound parameter instead of a double-quoted literal: safe for any
        # database name and valid under the ANSI_QUOTES sql_mode.
        sql = 'SELECT table_name FROM information_schema.tables WHERE table_schema = %s'
        res = self._execute_sql([(sql, (self.database_name,))])
        return flat_map(lambda x: x, res[0])

    def get_table_schema(self, table_name):
        """Describe ``table_name`` in this data source's database.

        Returns:
            list of strings ``"<index>:<Field>:<Type>"``, one per column, in
            DESCRIBE order.

        Raises:
            Exception('table not found') if DESCRIBE yields no rows. Driver
            errors, including a missing table, surface from _execute_sql.
        """
        sql = f'describe {self._quote_identifier(self.database_name)}.{self._quote_identifier(table_name)}'
        logger.info(sql)
        res = self._execute_sql([sql])
        # res always has one entry per statement, so test the row list itself.
        if not res or not res[0]:
            raise Exception('table not found')
        rows = [[str(i), *row] for i, row in enumerate(res[0])]
        logger.debug(rows)
        return flat_map(lambda row: [':'.join(self._cell_to_str(v) for v in row[:3])], rows)
|