# server.py (2.1 KB)
  1. from fastapi import FastAPI, Request
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from pydantic import BaseModel
  4. from paddleocr import PaddleOCR
  5. from sx_utils.sximage import *
  6. from sx_utils.sxtime import sxtimeit
  7. from sx_utils.sxweb import web_try
  8. from core.ocr import IdCardOcr
  9. import os
  10. app = FastAPI()
  11. origins = ["*"]
  12. app.add_middleware(
  13. CORSMiddleware,
  14. allow_origins=origins,
  15. allow_credentials=True,
  16. allow_methods=["*"],
  17. allow_headers=["*"],
  18. )
  19. # templates = Jinja2Templates(directory='templates')
  20. use_gpu = False
  21. if os.getenv('USE_CUDA') == 'gpu':
  22. use_gpu = True
  23. print(f'use gpu: {use_gpu}')
  24. # 初始化ocr模型和后处理模型
  25. ocr = PaddleOCR(use_angle_cls=True,
  26. rec_model_dir="./idcard_rec_infer/",
  27. det_model_dir="./idcard_det_infer/",
  28. cls_model_dir="idcard_cls_infer",
  29. rec_algorithm='CRNN',
  30. ocr_version='PP-OCRv2',
  31. rec_char_dict_path="./ppocr_keys_v1.txt", lang="ch",
  32. use_gpu=use_gpu,
  33. warmup=True)
  34. @app.get("/ping")
  35. def ping():
  36. return "pong!"
  37. # @app.get("/")
  38. # def home(request: Request):
  39. # ''' Returns html jinja2 template render for home page form
  40. # '''
  41. #
  42. # return templates.TemplateResponse('home.html', {
  43. # "request": request,
  44. # })
  45. class IdCardInfo(BaseModel):
  46. image: str
  47. image_type: str
  48. # /ocr_system/bankcard 银行卡
  49. # /ocr_system/regbook 户口本
  50. # /ocr_system/schoolcert 学信网
  51. @app.post("/ocr_system/idcard")
  52. @sxtimeit
  53. @web_try()
  54. def idcard(request: Request, id_card: IdCardInfo):
  55. image = base64_to_np(id_card.image)
  56. m = IdCardOcr(ocr, image, id_card.image_type)
  57. return m.predict()
  58. if __name__ == '__main__':
  59. import uvicorn
  60. import argparse
  61. parser = argparse.ArgumentParser()
  62. parser.add_argument('--host', default='0.0.0.0')
  63. parser.add_argument('--port', default=8080)
  64. opt = parser.parse_args()
  65. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  66. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)