server.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from fastapi import FastAPI, Request
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from pydantic import BaseModel
  4. from paddleocr import PaddleOCR
  5. import cv2
  6. from core.direction import AngleDetector
  7. from utils.image import *
  8. from utils.time import timeit
  9. from utils.web import web_try
  10. from core.ocr import IdCardOcr
  11. import os
  12. app = FastAPI()
  13. origins = ["*"]
  14. app.add_middleware(
  15. CORSMiddleware,
  16. allow_origins=origins,
  17. allow_credentials=True,
  18. allow_methods=["*"],
  19. allow_headers=["*"],
  20. )
  21. use_gpu = os.getenv('USE_CUDA') == 'gpu'
  22. print(f'use gpu: {use_gpu}')
  23. # 初始化ocr模型和后处理模型
  24. ocr = PaddleOCR(use_angle_cls=True,
  25. rec_model_dir="./models/rec/",
  26. det_model_dir="./models/det/",
  27. cls_model_dir="./models/cls/",
  28. rec_algorithm='CRNN',
  29. ocr_version='PP-OCRv2',
  30. lang="ch",
  31. use_gpu=use_gpu,
  32. det_db_unclip_ratio=1.7,
  33. warmup=True)
  34. # 初始化 角度检测器 对象
  35. ad = AngleDetector(ocr)
  36. m = IdCardOcr(ocr, ad)
  37. @app.get("/ping")
  38. def ping():
  39. return "pong!"
  40. class ParamInfo(BaseModel):
  41. image: str
  42. image_type: str
  43. @app.post("/ocr_system/regbook")
  44. @web_try()
  45. @timeit
  46. def detect(request: Request, param: ParamInfo):
  47. image = base64_to_np(param.image)
  48. if image.size > 36000000:
  49. image = cv2.resize(image, (int(image.shape[0]*0.8), int(image.shape[1]*0.8)), interpolation=cv2.INTER_CUBIC)
  50. return m.predict(image, param.image_type)
  51. if __name__ == '__main__':
  52. import uvicorn
  53. import argparse
  54. parser = argparse.ArgumentParser()
  55. parser.add_argument('--host', default='0.0.0.0')
  56. parser.add_argument('--port', default=8080)
  57. opt = parser.parse_args()
  58. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  59. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)