server.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. from fastapi import FastAPI, Request
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from pydantic import BaseModel
  4. from paddleocr import PaddleOCR
  5. from core.direction import AngleDetector
  6. from sx_utils.sximage import *
  7. from sx_utils.sxtime import sxtimeit
  8. from sx_utils.sxweb import web_try
  9. from core.ocr import IdCardOcr
  10. import os
  11. # 导入一些包
  12. app = FastAPI()
  13. origins = ["*"]
  14. # CORS 跨源资源共享
  15. app.add_middleware(
  16. CORSMiddleware,
  17. allow_origins=origins,
  18. allow_credentials=True,
  19. allow_methods=["*"],
  20. allow_headers=["*"],
  21. )
  22. # templates = Jinja2Templates(directory='templates')
  23. use_gpu = False
  24. if os.getenv('USE_CUDA') == 'gpu':
  25. use_gpu = True
  26. print(f'use gpu: {use_gpu}')
  27. # 初始化ocr模型和后处理模型
  28. # 分类
  29. # ocr = PaddleOCR(use_angle_cls=True,
  30. # # 方向
  31. # rec_model_dir="./idcard_rec_infer/",
  32. # det_model_dir="./idcard_det_infer/",
  33. # cls_model_dir="idcard_cls_infer",
  34. # # 识别
  35. # rec_algorithm='CRNN',
  36. # ocr_version='PP-OCRv2',
  37. # # 中文字典
  38. # rec_char_dict_path="./ppocr_keys_v1.txt", lang="ch",
  39. # use_gpu=use_gpu,
  40. # # 预训练-->效果不明显
  41. # # 网络不够大、不够深
  42. # # 数据集普遍较小,batch size普遍较小
  43. # warmup=True)
  44. # ocr = PaddleOCR(use_angle_cls=True,
  45. # use_gpu=use_gpu)
  46. ocr = PaddleOCR(use_angle_cls=True,
  47. rec_model_dir='./ch_ppocr_server_v2.0_rec_infer',
  48. det_model_dir='./ch_ppocr_server_v2.0_det_infer',
  49. cls_model_dir='./idcard_cls_infer',
  50. ocr_version='PP-OCRv2',
  51. rec_algorithm='CRNN',
  52. use_gpu=use_gpu,
  53. det_db_unclip_ratio=2.5,
  54. det_db_thresh=0.1,
  55. det_db_box_thresh=0.3,
  56. warmup=True)
  57. # 初始化 角度检测器 对象
  58. ad = AngleDetector(ocr)
  59. # 初始化 身份证ocr识别 对象
  60. m = IdCardOcr(ocr, ad)
  61. # Get 健康检查
  62. @app.get("/ping")
  63. def ping():
  64. return "pong!"
  65. # 解析传入的 json对象
  66. class IdCardInfo(BaseModel):
  67. image: str
  68. image_type: str
  69. # /ocr_system/bankcard 银行卡
  70. # /ocr_system/regbook 户口本
  71. # /ocr_system/schoolcert 学信网
  72. # Post 接口
  73. # 计算耗时
  74. # 异常处理
  75. @app.post("/ocr_system/idcard")
  76. @sxtimeit
  77. @web_try()
  78. # 传入=> base64码 -> np
  79. # 返回=> 检测到到结果 -> (conf, angle, parser, image_type)
  80. def idcard(request: Request, id_card: IdCardInfo):
  81. image = base64_to_np(id_card.image)
  82. return m.predict(image, id_card.image_type)
  83. @app.post("/ocr_system/orientation")
  84. @sxtimeit
  85. @web_try()
  86. # 传入=> base64码 -> np
  87. # 返回=> 检测到到结果 -> (conf, angle, parser, image_type)
  88. def detect_angle(request: Request, id_card: IdCardInfo):
  89. image = base64_to_np(id_card.image)
  90. angle, _ = ad.detect_angle(image, id_card.image_type)
  91. return {'orientation': angle}
  92. if __name__ == '__main__':
  93. import uvicorn
  94. import argparse
  95. parser = argparse.ArgumentParser()
  96. parser.add_argument('--host', default='0.0.0.0')
  97. parser.add_argument('--port', default=8080)
  98. opt = parser.parse_args()
  99. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  100. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)