mysql.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
from dataclasses import dataclass
from urllib.parse import quote_plus

from mysql import connector
from mysql.connector import Error
import pandas as pd
from sqlalchemy import create_engine

from app.core.ds.datasource import DataSourceBase
from configs import logger
from utils import flat_map
  9. @dataclass
  10. class MysqlDS(DataSourceBase):
  11. type = 'mysql'
  12. @property
  13. def jdbc_url(self):
  14. return f'jdbc:mysql://{self.host}:{self.port}/{self.database_name}'
  15. @property
  16. def jdbc_driver_class(self):
  17. return 'com.mysql.jdbc.Driver'
  18. @property
  19. def connection_str(self):
  20. return f'mysql+pymysql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database_name}'
  21. def is_connect(self):
  22. # 判断mysql是否连接成功
  23. conn = None
  24. try:
  25. conn = connector.connect(host=self.host,
  26. port=self.port,
  27. database=self.database_name,
  28. user=self.username,
  29. password=self.password)
  30. if conn.is_connected():
  31. logger.info('Connected to MySQL database')
  32. except Error as e:
  33. logger.error(e)
  34. finally:
  35. if conn is not None and conn.is_connected():
  36. conn.close()
  37. return True
  38. else:
  39. return False
  40. def _execute_sql(self, sqls):
  41. conn = None
  42. res = []
  43. try:
  44. conn = connector.connect(host=self.host,
  45. port=self.port,
  46. database=self.database_name,
  47. user=self.username,
  48. password=self.password)
  49. cursor = conn.cursor()
  50. for sql in sqls:
  51. cursor.execute(sql)
  52. res.append(cursor.fetchall())
  53. logger.info(res)
  54. except Error as e:
  55. logger.error(e)
  56. finally:
  57. if conn is not None and conn.is_connected():
  58. conn.close()
  59. return res
  60. def get_preview_data(self, table_name, limit=100):
  61. sql1 = f'SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA="{self.database_name}" AND TABLE_NAME="{table_name}"'
  62. sql2 = f"SELECT * FROM {table_name} LIMIT {limit}"
  63. res = self._execute_sql([sql1, sql2])
  64. logger.info(res)
  65. return {
  66. 'header': flat_map(lambda x: x, res[0]),
  67. 'content': res[1]
  68. }
  69. # db_connection = create_engine(self.connection_str)
  70. # df = pd.read_sql(sql, con=db_connection)
  71. # # print(df)
  72. # logger.info(df.head())
  73. # return df.to_numpy()
  74. def list_tables(self):
  75. sql = f'SELECT table_name FROM information_schema.tables WHERE table_type = "base table" AND table_schema="{self.database_name}"'
  76. res = self._execute_sql([sql])
  77. return flat_map(lambda x: x, res[0])
  78. def get_table_schema(self, table_name):
  79. sql = f'describe {self.database_name}.{table_name}'
  80. res = self._execute_sql([sql])
  81. if res:
  82. res = [[str(i) , *x]for i, x in enumerate(res[0])]
  83. logger.info(res)
  84. return flat_map(lambda x: [':'.join(x[:3])], res)
  85. else:
  86. raise Exception('table not found')