server.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from fastapi import FastAPI, Request
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from pydantic import BaseModel
  4. from paddleocr import PaddleOCR
  5. from blfe_core.direction import AngleDetector
  6. from utils.image import *
  7. from utils.time import timeit
  8. from utils.web import web_try
  9. from blfe_core.ocr import BusinessLicenseOcr
  10. import os
  app = FastAPI()

  # CORS (cross-origin resource sharing): origins allowed to call this API.
  # "*" admits every origin.
  origins = ["*"]

  # NOTE(review): the CORS spec does not let browsers honour a literal "*"
  # origin on credentialed requests; depending on the Starlette version the
  # middleware may instead echo the caller's Origin, which would let any site
  # send credentialed requests. Consider explicit origins or disabling
  # credentials — confirm against the deployed Starlette version.
  _cors_settings = {
      "allow_origins": origins,
      "allow_credentials": True,
      "allow_methods": ["*"],
      "allow_headers": ["*"],
  }
  app.add_middleware(CORSMiddleware, **_cors_settings)
  # Device selection for PaddleOCR: the GPU is used only when the USE_CUDA
  # environment variable is exactly 'gpu'; any other value (or unset) means CPU.
  use_gpu = os.environ.get('USE_CUDA') == 'gpu'
  print(f'use gpu: {use_gpu}')

  # Shared PaddleOCR engine, built once at import time from local model
  # directories (det / rec / cls). The det_db_* keyword arguments are
  # PaddleOCR's DB text-detector post-processing knobs; see the PaddleOCR
  # parameter docs for their exact meaning.
  # NOTE(review): the model paths are relative to the current working
  # directory, not to this file — the server must be started from the
  # directory that contains ./models.
  ocr = PaddleOCR(
      use_gpu=use_gpu,
      det_model_dir="./models/det",
      rec_model_dir="./models/rec",
      cls_model_dir="./models/cls",
      det_db_thresh=0.1,
      det_db_box_thresh=0.3,
      det_db_unclip_ratio=1.8,
  )

  # Angle detector built on the shared OCR engine (presumably used to
  # determine image orientation — confirm in blfe_core.direction).
  ad = AngleDetector(ocr)
  # Business-licence OCR pipeline, given both the OCR engine and the angle
  # detector; the POST endpoint calls m.predict(...) on decoded images.
  m = BusinessLicenseOcr(ocr, ad)
  # GET /ping — health check. Always responds with the fixed string "pong!".
  @app.get("/ping")
  def ping():
      reply = "pong!"
      return reply
  # Request body model: the JSON object posted to the OCR endpoint is parsed
  # and validated into this class.
  class BusinessLicenseInfo(BaseModel):
      # Encoded image as a string; the endpoint decodes it with base64_to_np,
      # so presumably base64 image bytes — confirm the accepted format in
      # utils.image.
      image: str


  # POST /ocr_system/business_license — decode the posted image into an array
  # and run the business-licence OCR pipeline on it.
  # Decorator order matters: web_try() wraps the handler first, timeit wraps
  # that, and the result is what FastAPI registers.
  # NOTE(review): `request` is not used in the body; it is kept to leave the
  # endpoint signature unchanged.
  @app.post("/ocr_system/business_license")
  @timeit
  @web_try()
  def blfe(request: Request, blfe: BusinessLicenseInfo):
      decoded = base64_to_np(blfe.image)
      prediction = m.predict(decoded)
      return prediction
  # Development entry point: `python <this file> [--host H] [--port P] [--no-reload]`.
  if __name__ == '__main__':
      import argparse

      import uvicorn

      parser = argparse.ArgumentParser(description='Business-licence OCR HTTP server')
      parser.add_argument('--host', default='0.0.0.0')
      # type=int lets argparse reject a non-numeric port with a usage message
      # instead of crashing later with a ValueError traceback.
      parser.add_argument('--port', type=int, default=8080)
      # Auto-reload stays on by default (the previous behaviour); --no-reload
      # is meant for deployments.
      parser.add_argument('--no-reload', action='store_true',
                          help='disable uvicorn auto-reload')
      opt = parser.parse_args()

      if opt.no_reload:
          # Serve the app object already built in this process. Passing an
          # import string instead would import this file again under its module
          # name and construct the PaddleOCR engine a second time.
          uvicorn.run(app, host=opt.host, port=opt.port)
      else:
          # uvicorn can only reload from an import string. Derive the module
          # name from this file's name so renaming the file keeps working.
          module_name = os.path.splitext(os.path.basename(__file__))[0]
          app_str = f'{module_name}:app'
          uvicorn.run(app_str, host=opt.host, port=opt.port, reload=True)