Zhang Li преди 2 години
ревизия
b6a68d2490

+ 3 - 0
.gitignore

@@ -0,0 +1,3 @@
+__pycache__/
+/.idea/
+/.vscode/

+ 0 - 0
app/__init__.py


+ 0 - 0
app/core/__init__.py


+ 66 - 0
app/core/datasource.py

@@ -0,0 +1,66 @@
+
+from dataclasses import dataclass
+import  base64
+from email.mime import base
+
+
+@dataclass
+class DataSourceBase:
+    type: str
+    host: str
+    port: int
+    username: str
+    password: str
+    database_name: str
+
+    @property
+    def jdbc_username(self):
+        return base64.encodebytes(self.username.encode('utf-8')).decode('utf-8')
+
+
+    @property
+    def jdbc_password(self):
+        return base64.encodebytes(self.password.encode('utf-8')).decode('utf-8')
+
+    @property
+    def jdbc_url(self):
+        pass
+
+    @property
+    def jdbc_driver_class(self):
+        pass
+
+
+class MysqlDS(DataSourceBase):
+    type = 'mysql'
+
+    @property
+    def jdbc_url(self):
+        return f'jdbc:mysql://{self.host}:{self.port}/{self.database_name}'
+
+    @property
+    def jdbc_driver_class(self):
+        return 'com.mysql.jdbc.Driver'
+
+
+class HiveDS(DataSourceBase):
+    type = 'hive'
+
+    @property
+    def jdbc_url(self):
+        return f'jdbc:hive2://{self.host}:{self.port}/{self.database_name}'
+
+    @property
+    def jdbc_driver_class(self):
+        return 'org.apache.hive.jdbc.HiveDriver'
+
+
+class DataSrouceFactory:
+    @staticmethod
+    def create(ds_type: str, ds_config: dict):
+        if ds_type == 'mysql':
+            return MysqlDS(**ds_config, type=ds_type)
+        elif ds_type == 'hive':
+            return HiveDS(**ds_config, type=ds_type)
+        else:
+            raise Exception('不支持的数据源类型')

+ 1 - 0
app/crud/__init__.py

@@ -0,0 +1 @@
+from app.crud.job_jdbc_datasource import *

+ 76 - 0
app/crud/job_jdbc_datasource.py

@@ -0,0 +1,76 @@
+import time
+from typing import List
+from sqlalchemy.orm import Session
+from app.core.datasource import DataSrouceFactory
+import app.schemas as schemas
+import  app.models as models
+
+def _format_datasource(item: schemas.JobJdbcDatasourceBase):
+    host, port = item.jdbc_url.split(':')
+    if not host or not port:
+        raise Exception('jdbc_url is invalid')
+
+    ds = DataSrouceFactory.create(item.datasource, {'port': port, 'host': host, 'username': item.jdbc_username, 'password': item.jdbc_password, 'database_name': item.database_name})
+    item.jdbc_url = ds.jdbc_url
+    item.jdbc_username = ds.jdbc_username
+    item.jdbc_password = ds.jdbc_password
+    return ds, item
+
+def create_job_jdbc_datasource(db: Session, item: schemas.JobJdbcDatasourceCreate):
+    ds, item = _format_datasource(item)
+
+    #
+    create_time: int = int(time.time())
+    db_item = models.JobJdbcDatasource(**item.dict(), **{
+        'status': 1,
+        'create_time': create_time,
+        'create_by': 'admin',
+        'update_time': create_time,
+        'update_by': 'admin',
+        'jdbc_driver_class': ds.jdbc_driver_class
+    })
+    db.add(db_item)
+    db.commit()
+    db.refresh(db_item)
+    print(db_item)
+    return db_item
+
+
+def get_job_jdbc_datasources(db: Session, skip: int = 0, limit: int = 20):
+    def _decode(url, datasource, database_name):
+        return url.replace('jdbc:', '').replace(f'{datasource}://', '').replace(f'/{database_name}', '')
+
+    res: List[models.JobJdbcDatasource] = db.query(models.JobJdbcDatasource).filter(models.JobJdbcDatasource.status==1).all()
+    for item in res:
+        item.jdbc_url = _decode(item.jdbc_url, item.datasource, item.database_name)
+    return res
+
+    # return db.query(models.JobJdbcDatasource).offset(skip).limit(limit).all()
+
+
+def update_job_jdbc_datasources(db: Session, ds_id: int, update_item: schemas.JobJdbcDatasourceUpdate):
+    ds, update_item = _format_datasource(update_item)
+
+    db_item = db.query(models.JobJdbcDatasource).filter(models.JobJdbcDatasource.id == ds_id).first()
+    if not db_item:
+        raise Exception('JobJdbcDatasource not found')
+    update_dict = update_item.dict(exclude_unset=True)
+    for k, v in update_dict.items():
+        setattr(db_item, k, v)
+    db_item.jdbc_driver_class = ds.jdbc_driver_class
+    db_item.update_time = int(time.time())
+    db_item.update_by = 'admin1'
+    db.commit()
+    db.flush()
+    db.refresh(db_item)
+    return db_item
+
+def delete_job_jdbc_datasource(db: Session, ds_id: int):
+    db_item = db.query(models.JobJdbcDatasource).filter(models.JobJdbcDatasource.id == ds_id).first()
+    if not db_item:
+        raise Exception('JobJdbcDatasource not found')
+    db_item.status = 0
+    db.commit()
+    db.flush()
+    db.refresh(db_item)
+    return db_item

+ 1 - 0
app/models/__init__.py

@@ -0,0 +1 @@
+from app.models.job_jdbc_datasource import *

+ 25 - 0
app/models/database.py

@@ -0,0 +1,25 @@
+ # database.py
+from sqlalchemy import create_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+
+USER = 'root'
+PWD = 'happylay'
+DB_NAME = 'datax_web_dev'
+HOST = '192.168.199.107'
+PORT = '10086'
+
+SQLALCHEMY_DATABASE_URL = f'mysql+mysqlconnector://{USER}:{PWD}@{HOST}:{PORT}/{DB_NAME}?charset=utf8&auth_plugin=mysql_native_password'
+engine = create_engine(
+    SQLALCHEMY_DATABASE_URL, pool_pre_ping=True
+)
+
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+Base = declarative_base()
+
+class BaseModel(Base):
+    __abstract__ = True
+    def to_dict(self):
+        return {c.name: getattr(self, c.name) for c in self.__table__.columns}
+

+ 36 - 0
app/models/job_jdbc_datasource.py

@@ -0,0 +1,36 @@
+from sqlalchemy import Boolean, Column, ForeignKey, Integer, String
+
+from app.models.database import BaseModel
+
+
+class JobJdbcDatasource(BaseModel):
+    __tablename__ = "job_jdbc_datasource"
+
+    id = Column(Integer, primary_key=True, index=True)
+    # 数据源名称
+    datasource_name = Column(String, nullable=False, unique=True, index=True)
+    # 数据源
+    datasource = Column(String, nullable=False)
+    # 数据库名
+    database_name = Column(String)
+    # 数据库用户名
+    jdbc_username = Column(String)
+    # 数据库密码
+    jdbc_password = Column(String)
+    # jdbc url
+    jdbc_url = Column(String)
+    # jdbc driver
+    jdbc_driver_class = Column(String)
+    # 状态: 0 删除 1 启用 2 禁用
+    status = Column(Integer, default=1, nullable=False)
+    # 创建时间
+    create_time = Column(Integer)
+    # 创建人
+    create_by = Column(String)
+    # 更新时间
+    update_time = Column(Integer)
+    # 更新人
+    update_by = Column(String)
+    # 备注
+    comments = Column(String)
+

+ 0 - 0
app/routers/__init__.py


+ 5 - 0
app/routers/job_jdbc_datasource.py

@@ -0,0 +1,5 @@
+from fastapi import APIRouter
+
+router = APIRouter()
+
+

+ 1 - 0
app/schemas/__init__.py

@@ -0,0 +1 @@
+from app.schemas.job_jdbc_datasouce import *

+ 53 - 0
app/schemas/job_jdbc_datasouce.py

@@ -0,0 +1,53 @@
+from typing import List
+
+from pydantic import BaseModel
+
+
+class JobJdbcDatasourceBase(BaseModel):
+    # 数据源名称
+    datasource_name: str
+    # 数据源
+    datasource: str
+    # 数据库名
+    database_name: str
+    # 数据库用户名
+    jdbc_username: str
+    # 数据库密码
+    jdbc_password: str
+    # jdbc url
+    jdbc_url: str
+    # 备注
+    comments: str
+    class Config:
+        schema_extra = {
+            "example": {
+                "datasource_name": 'test',
+                "datasource": "mysql",
+                "database_name": 'datax_web',
+                "jdbc_username": 'root',
+                "jdbc_password": 'happylay',
+                "jdbc_url": '192.168.199.107:10086',
+                "comments": 'This is a very nice Item'
+            }
+        }
+
+
+class JobJdbcDatasourceCreate(JobJdbcDatasourceBase):
+    pass
+
+
+class JobJdbcDatasourceUpdate(JobJdbcDatasourceBase):
+    pass
+
+class JobJdbcDatasource(JobJdbcDatasourceBase):
+    id: int
+    status: int
+    create_time: int
+    create_by: str
+    update_time: int
+    update_by: str
+    jdbc_url: str
+    jdbc_driver_class: str
+
+    class Config:
+        orm_mode = True

+ 0 - 0
configs/__init__.py


+ 10 - 0
configs/database.py

@@ -0,0 +1,10 @@
+from sqlalchemy import create_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+
+SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
+
+engine = create_engine(
+    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
+)
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

+ 1 - 0
requirements.txt

@@ -0,0 +1 @@
+fastapi_pagination=0.9.3

+ 12 - 0
run.py

@@ -0,0 +1,12 @@
+
+if __name__ == '__main__':
+    import uvicorn
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', default='0.0.0.0')
+    parser.add_argument('--port', default=8080)
+    opt = parser.parse_args()
+
+    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)

+ 86 - 0
server.py

@@ -0,0 +1,86 @@
+from asyncio.constants import DEBUG_STACK_DEPTH
+from typing import List
+
+from fastapi import Depends, FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from sqlalchemy.orm import Session
+from app import schemas
+
+from app.models.database import SessionLocal, engine, Base
+import app.crud as crud
+from utils.sx_time import sxtimeit
+from utils.sx_web import web_try
+from fastapi_pagination import Page, add_pagination, paginate, Params
+
+Base.metadata.create_all(bind=engine)
+
+app = FastAPI()
+
+
+
+# CORS 跨源资源共享
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Dependency
+def get_db():
+    try:
+        db = SessionLocal()
+        yield db
+    finally:
+        db.close()
+
+
+# Get 健康检查
+@app.get("/ping")
+def ping():
+    return "pong!"
+
+
+
+
+@app.post("/datasource/")
+@web_try()
+@sxtimeit
+def create_datasource(ds: schemas.JobJdbcDatasourceCreate, db: Session = Depends(get_db)):
+    return crud.create_job_jdbc_datasource(db, ds)
+
+
+
+@app.get("/datasource/")
+@web_try()
+@sxtimeit
+def get_datasources(params: Params=Depends(), db: Session = Depends(get_db)):
+    return paginate(crud.get_job_jdbc_datasources(db), params)
+
+@app.put("/datasource/{ds_id}")
+@web_try()
+@sxtimeit
+def update_datasource(ds_id: int, ds: schemas.JobJdbcDatasourceUpdate, db: Session = Depends(get_db)):
+    return crud.update_job_jdbc_datasources(db, ds_id, ds)
+
+@app.delete("/datasource/{ds_id}")
+@web_try()
+@sxtimeit
+def delete_job_jdbc_datasource(ds_id: int, db: Session = Depends(get_db)):
+    return crud.delete_job_jdbc_datasource(db, ds_id)
+
+add_pagination(app)
+
+
+if __name__ == '__main__':
+    import uvicorn
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', default='0.0.0.0')
+    parser.add_argument('--port', default=8080)
+    opt = parser.parse_args()
+
+    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)

+ 2 - 0
utils/__init__.py

@@ -0,0 +1,2 @@
+from utils.sx_web import *
+from utils.sx_time import *

+ 125 - 0
utils/sx_time.py

@@ -0,0 +1,125 @@
+# coding=utf-8
+# Powered by SoaringNova Technology Company
+import errno
+import os
+import signal
+import time
+from collections import defaultdict
+from functools import wraps
+
+timer_counts = defaultdict(int)
+
+
+class SXTimeoutError(Exception):
+    pass
+
+
+def sxtimeout(seconds=10, error_message=os.strerror(errno.ETIME)):
+    def decorator(func):
+        def _handle_timeout(signum, frame):
+            raise SXTimeoutError(error_message)
+
+        def wrapper(*args, **kwargs):
+            signal.signal(signal.SIGALRM, _handle_timeout)
+            signal.alarm(seconds)
+            try:
+                result = func(*args, **kwargs)
+            finally:
+                signal.alarm(0)
+            return result
+
+        return wraps(func)(wrapper)
+
+    return decorator
+
+
+class SXTIMELIMIT:
+    def __init__(self, limit_time=0):
+        self.st = None
+        self.et = None
+        self.limit_time = limit_time
+
+    def __enter__(self):
+        self.st = time.time()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.et = time.time()
+        dt = self.limit_time - (self.et - self.st) * 1000
+        if dt > 0: time.sleep(float(dt) / 1000)
+
+
+class SXTIMER:
+    total_time = {}  # type: dict
+
+    def __init__(self, tag='', enable_total=False, threshold_ms=0):
+        self.st = None
+        self.et = None
+        self.tag = tag
+        # self.tag = tag if not hasattr(g,'request_id') else '{} {}'.format(getattr(g,'request_id'),tag)
+
+        self.thr = threshold_ms
+        self.enable_total = enable_total
+        if self.enable_total:
+            if self.tag not in self.total_time.keys():
+                self.total_time[self.tag] = []
+
+    def __enter__(self):
+        self.st = time.time()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.et = time.time()
+        dt = (self.et - self.st) * 1000
+        if self.enable_total:
+            self.total_time[self.tag].append(dt)
+
+        if dt > self.thr:
+            print("{}: {}s".format(self.tag, round(dt / 1000, 4)))
+
+    @staticmethod
+    def output():
+        for k, v in SXTIMER.total_time.items():
+            print('{} : {}s, avg{}s'.format(k, round(sum(v) / 1000, 2), round(sum(v) / len(v) / 1000, 2)))
+
+
+def sxtimeit(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        st = time.time()
+        ret = func(*args, **kwargs)
+        dt = time.time() - st
+        endpoint = '{}.{}'.format(func.__module__, func.__name__)
+        timer_counts[endpoint] += 1
+        print('{}[{}] finished, exec {}s'.format(endpoint, '%05d' % timer_counts[endpoint], round(dt, 4)))
+        return ret
+
+    return wrapper  # 返回
+
+
+# def sxtimeit(func):
+#     @wraps(func)
+#     def wrapper(*args, **kwargs):
+#         endpoint = '{}.{}'.format(func.__module__, func.__name__)
+#         setattr(g,'request_id','{}[{}]'.format(endpoint,'%05d' % timer_counts[endpoint]))
+#         timer_counts[endpoint] += 1
+#         st = time.time()
+#         ret = func(*args, **kwargs)
+#         dt = time.time() - st
+#         print ('{} finished, exec {}s'.format(getattr(g,'request_id'), round(dt, 4)))
+#         return ret
+#
+#     return wrapper  # 返回
+
+def t2date(t):
+    import datetime
+    date = datetime.datetime.fromtimestamp(t)
+    return '{}_{}_{}_{}:{}:{}'.format(date.year, date.month, date.day, date.hour, date.minute, date.second)
+
+
+def day_begin(t):
+    dsecs = 24 * 3600
+    return (int(t) + 8 * 3600) // dsecs * dsecs - 8 * 3600
+
+
+def hour_begin(t):
+    hsecs = 3600
+    return (int(t) + 8 * 3600) // hsecs * hsecs - 8 * 3600

+ 43 - 0
utils/sx_web.py

@@ -0,0 +1,43 @@
+import traceback
+import numpy as np
+from decorator import decorator
+from fastapi.responses import JSONResponse
+
+
+def json_compatible(data):
+    if isinstance(data,dict):
+        return {k:json_compatible(v) for k,v in data.items()}
+    if isinstance(data,bytes):
+        return str(data)
+    if isinstance(data,np.ndarray):
+        return data.tolist()
+    return data
+
+def web_try(exception_ret=None):
+    @decorator
+    def f(func, *args, **kwargs):
+        error_code = 200
+        ret = None
+        msg = ''
+        try:
+            ret = func(*args, **kwargs)
+        except Exception as e:
+            msg = traceback.format_exc()
+            if len(e.args) > 0 and isinstance(e.args[0], int):
+                error_code = e.args[0]
+            else:
+                error_code = 400
+            print('--------------------------------')
+            print ('Get Exception in web try :( \n{}\n'.format(msg))
+            print('--------------------------------')
+            if callable(exception_ret):
+                ret = exception_ret()
+            else:
+                ret = exception_ret
+        finally:
+            if ret is not None and isinstance(ret, JSONResponse):
+                return ret
+            return json_compatible({"code": error_code,
+                                    "data": ret,
+                                    "msg": msg.split('\n')[-2] if msg is not '' else msg})
+    return f