server.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from fastapi import FastAPI, Request
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from pydantic import BaseModel
  4. from paddleocr import PaddleOCR
  5. from core.direction import AngleDetector
  6. from sx_utils.sximage import *
  7. from sx_utils.sxtime import sxtimeit
  8. from sx_utils.sxweb import web_try
  9. from core.ocr import IdCardOcr
  10. import os
  # FastAPI application with a fully permissive CORS policy: any origin,
  # method and header is accepted.
  # NOTE(review): the FastAPI CORS documentation discourages combining wildcard
  # origins with allow_credentials=True -- confirm whether credentials are
  # actually needed by the clients of this service.
  app = FastAPI()
  origins = ["*"]
  app.add_middleware(
      CORSMiddleware,
      allow_origins=origins,
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )
  # templates = Jinja2Templates(directory='templates')

  # GPU inference is requested only when USE_CUDA is set to exactly 'gpu';
  # any other value (or an unset variable) keeps inference on the CPU.
  use_gpu = os.getenv('USE_CUDA') == 'gpu'
  print(f'use gpu: {use_gpu}')

  # Initialize the OCR model and the post-processing models.
  # The detection / recognition / classification model directories and the
  # character dictionary are resolved relative to the working directory.
  ocr = PaddleOCR(
      use_angle_cls=True,
      rec_model_dir="./idcard_rec_infer/",
      det_model_dir="./idcard_det_infer/",
      cls_model_dir="idcard_cls_infer",
      rec_algorithm='CRNN',
      ocr_version='PP-OCRv2',
      rec_char_dict_path="./ppocr_keys_v1.txt",
      lang="ch",
      use_gpu=use_gpu,
      warmup=True,
  )
  # Alternative construction kept by the original author (no custom model
  # paths; presumably falls back to PaddleOCR's own defaults):
  # ocr = PaddleOCR(use_angle_cls=True,
  #                 use_gpu=use_gpu,
  #                 warmup=True)

  # Both helpers share the single OCR engine instance created above.
  ad = AngleDetector(ocr)
  m = IdCardOcr(ocr, ad)
  @app.get("/ping")
  def ping():
      """Liveness probe: always answers with the fixed string ``pong!``."""
      return "pong!"
  43. # @app.get("/")
  44. # def home(request: Request):
  45. # ''' Returns html jinja2 template render for home page form
  46. # '''
  47. #
  48. # return templates.TemplateResponse('home.html', {
  49. # "request": request,
  50. # })
  class IdCardInfo(BaseModel):
      """Request body accepted by the ``/ocr_system/idcard`` endpoint."""

      # Image payload; the handler hands it to ``base64_to_np``, so it is
      # presumably a base64-encoded image -- TODO confirm the exact encoding.
      image: str
      # Passed through unchanged as the second argument of ``m.predict``;
      # the set of accepted values is not visible in this file.
      image_type: str
  # Other endpoint paths noted by the original author (no handlers for them
  # are defined in the visible part of this file):
  # /ocr_system/bankcard    bank card
  # /ocr_system/regbook     household register booklet
  # /ocr_system/schoolcert  CHSI (xuexin.com) academic credential
  @app.post("/ocr_system/idcard")
  @sxtimeit
  @web_try()
  def idcard(request: Request, id_card: IdCardInfo):
      """Run ID-card OCR on the image carried by the request body.

      Args:
          request: The incoming request object; not used by the handler body.
          id_card: Parsed request body holding the image string and its type.

      Returns:
          Whatever ``m.predict`` returns for the decoded image. The result may
          be further wrapped by the ``sxtimeit`` / ``web_try`` decorators --
          their behaviour is not visible in this file.
      """
      # base64_to_np comes from the sx_utils.sximage star import; presumably
      # it decodes the base64 string into a NumPy image array -- TODO confirm.
      decoded_image = base64_to_np(id_card.image)
      return m.predict(decoded_image, id_card.image_type)
  if __name__ == '__main__':
      # Development launcher: parse CLI options and start uvicorn on this module.
      import argparse

      import uvicorn

      parser = argparse.ArgumentParser()
      parser.add_argument('--host', default='0.0.0.0')
      # type=int makes argparse reject a non-numeric port with a usage error
      # instead of a ValueError traceback from a later int() call.
      parser.add_argument('--port', type=int, default=8080)
      # Auto-reload stays on by default (the previous hard-coded behaviour);
      # --no-reload is an opt-out for runs where file watching is unwanted.
      parser.add_argument('--no-reload', dest='reload', action='store_false',
                          help='disable uvicorn auto-reload (enabled by default)')
      opt = parser.parse_args()

      # uvicorn needs an import string (not the app object) when reload is on.
      # Derive the module part from this file's name so that renaming the file
      # does not silently break the launcher.
      module_name = os.path.splitext(os.path.basename(__file__))[0]
      app_str = f'{module_name}:app'
      uvicorn.run(app_str, host=opt.host, port=opt.port, reload=opt.reload)