server.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. from fastapi import FastAPI, Request
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from pydantic import BaseModel
  4. from paddleocr import PaddleOCR
  5. from sx_utils.sximage import *
  6. from sx_utils.sxtime import sxtimeit
  7. from sx_utils.sxweb import web_try
  8. from core.ocr import IdCardOcr
  9. import os
  10. app = FastAPI()
  11. origins = ["*"]
  12. app.add_middleware(
  13. CORSMiddleware,
  14. allow_origins=origins,
  15. allow_credentials=True,
  16. allow_methods=["*"],
  17. allow_headers=["*"],
  18. )
  19. # templates = Jinja2Templates(directory='templates')
  20. use_gpu = False
  21. if os.getenv('USE_CUDA') == 'gpu':
  22. use_gpu = True
  23. print(f'use gpu: {use_gpu}')
  24. # 初始化ocr模型和后处理模型
  25. ocr = PaddleOCR(use_angle_cls=True,
  26. rec_model_dir="./ch_ppocr_server_v2.0_rec_infer/",
  27. det_model_dir="./ch_ppocr_server_v2.0_det_infer/",
  28. cls_model_dir="./huko_cls_infer/",
  29. rec_algorithm='CRNN',
  30. ocr_version='PP-OCRv2',
  31. rec_char_dict_path="./ppocr_keys_v1.txt", lang="ch",
  32. use_gpu=use_gpu,
  33. warmup=True)
  34. # ocr = PaddleOCR(use_angle_cls=True,
  35. # use_gpu=use_gpu)
  36. m = IdCardOcr(ocr)
  37. @app.get("/ping")
  38. def ping():
  39. return "pong!"
  40. # @app.get("/")
  41. # def home(request: Request):
  42. # ''' Returns html jinja2 template render for home page form
  43. # '''
  44. #
  45. # return templates.TemplateResponse('home.html', {
  46. # "request": request,
  47. # })
  48. class ParamInfo(BaseModel):
  49. image: str
  50. image_type: str
  51. # /ocr_system/bankcard 银行卡
  52. # /ocr_system/regbook 户口本
  53. # /ocr_system/schoolcert 学信网
  54. @app.post("/ocr_system/regbook")
  55. @web_try()
  56. @sxtimeit
  57. def detect(request: Request, param: ParamInfo):
  58. image = base64_to_np(param.image)
  59. return m.predict(image, param.image_type)
  60. if __name__ == '__main__':
  61. import uvicorn
  62. import argparse
  63. parser = argparse.ArgumentParser()
  64. parser.add_argument('--host', default='0.0.0.0')
  65. parser.add_argument('--port', default=8080)
  66. opt = parser.parse_args()
  67. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  68. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)