# mysql.py
from dataclasses import dataclass
from urllib.parse import quote

import pandas as pd
from mysql import connector
from mysql.connector import Error
from sqlalchemy import create_engine

from app.core.datasource.datasource import DataSourceBase
from configs import logger
from utils import flat_map
  9. class MysqlDS(DataSourceBase):
  10. def __init__(self, host: str, port: int, username: str, password: str, database_name: str, use_ssl: int, type: str='mysql'):
  11. DataSourceBase.__init__(self, type, host, port, username, password, database_name, )
  12. self.use_ssl = use_ssl
  13. @property
  14. def jdbc_url(self):
  15. jdbc = f'jdbc:mysql://{self.host}:{self.port}/{self.database_name}'
  16. if self.use_ssl == 0:
  17. jdbc = f'{jdbc}?useSSL=false'
  18. return jdbc
  19. @property
  20. def jdbc_driver_class(self):
  21. return 'com.mysql.jdbc.Driver'
  22. @property
  23. def connection_str(self):
  24. return f'mysql+pymysql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database_name}'
  25. def is_connect(self):
  26. # 判断mysql是否连接成功
  27. conn = None
  28. try:
  29. use_ssl = False if self.use_ssl == 0 else True
  30. conn = connector.connect(host=self.host,
  31. port=self.port,
  32. database=self.database_name,
  33. user=self.username,
  34. password=self.password,
  35. ssl_disabled=not use_ssl)
  36. if conn.is_connected():
  37. logger.info('Connected to MySQL database')
  38. except Error as e:
  39. logger.error(e)
  40. finally:
  41. if conn is not None and conn.is_connected():
  42. conn.close()
  43. return True
  44. else:
  45. return False
  46. def _execute_sql(self, sqls):
  47. conn = None
  48. res = []
  49. try:
  50. use_ssl = False if self.use_ssl == 0 else True
  51. conn = connector.connect(host=self.host,
  52. port=self.port,
  53. database=self.database_name,
  54. user=self.username,
  55. password=self.password,
  56. ssl_disabled=not use_ssl)
  57. cursor = conn.cursor()
  58. for sql in sqls:
  59. cursor.execute(sql)
  60. res.append(cursor.fetchall())
  61. logger.info(res)
  62. except Error as e:
  63. logger.error(e)
  64. finally:
  65. if conn is not None and conn.is_connected():
  66. conn.close()
  67. return res
  68. def get_preview_data(self, table_name, limit=100):
  69. sql1 = f'SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA="{self.database_name}" AND TABLE_NAME="{table_name}"'
  70. sql2 = f"SELECT * FROM {table_name} LIMIT {limit}"
  71. res = self._execute_sql([sql1, sql2])
  72. logger.info(res)
  73. return {
  74. 'header': flat_map(lambda x: x, res[0]),
  75. 'content': res[1]
  76. }
  77. # db_connection = create_engine(self.connection_str)
  78. # df = pd.read_sql(sql, con=db_connection)
  79. # # print(df)
  80. # logger.info(df.head())
  81. # return df.to_numpy()
  82. def list_tables(self):
  83. # table_type = "base table" AND
  84. sql = f'SELECT table_name FROM information_schema.tables WHERE table_schema="{self.database_name}"'
  85. res = self._execute_sql([sql])
  86. return flat_map(lambda x: x, res[0])
  87. def get_table_schema(self, table_name):
  88. sql = f'describe `{self.database_name}`.{table_name}'
  89. logger.info(sql)
  90. res = self._execute_sql([sql])
  91. if res:
  92. res = [[str(i) , *x]for i, x in enumerate(res[0])]
  93. logger.info(res)
  94. return flat_map(lambda x: [':'.join(x[:3])], res)
  95. else:
  96. raise Exception('table not found')