server.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from fastapi import FastAPI, Request
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from pydantic import BaseModel
  4. from paddleocr import PaddleOCR
  5. from core.direction import AngleDetector
  6. from sx_utils.sximage import *
  7. from sx_utils.sxtime import sxtimeit
  8. from sx_utils.sxweb import web_try
  9. from core.ocr import CetOcr
  10. import os
  11. # 导入一些包
  12. app = FastAPI()
  13. origins = ["*"]
  14. # CORS 跨源资源共享
  15. app.add_middleware(
  16. CORSMiddleware,
  17. allow_origins=origins,
  18. allow_credentials=True,
  19. allow_methods=["*"],
  20. allow_headers=["*"],
  21. )
  22. # templates = Jinja2Templates(directory='templates')
  23. use_gpu = False
  24. if os.getenv('USE_CUDA') == 'gpu':
  25. use_gpu = True
  26. print(f'use gpu: {use_gpu}')
  27. # 初始化ocr模型和后处理模型
  28. # 分类
  29. # ocr = PaddleOCR(use_angle_cls=True,
  30. # # 方向
  31. # rec_model_dir="./idcard_rec_infer/",
  32. # det_model_dir="./idcard_det_infer/",
  33. # cls_model_dir="idcard_cls_infer",
  34. # # 识别
  35. # rec_algorithm='CRNN',
  36. # ocr_version='PP-OCRv2',
  37. # # 中文字典
  38. # rec_char_dict_path="./ppocr_keys_v1.txt", lang="ch",
  39. # use_gpu=use_gpu,
  40. # # 预训练-->效果不明显
  41. # # 网络不够大、不够深
  42. # # 数据集普遍较小,batch size普遍较小
  43. # warmup=True)
  44. # ocr = PaddleOCR(use_angle_cls=True,
  45. # use_gpu=use_gpu)
  46. ocr = PaddleOCR(use_angle_cls=True,
  47. use_gpu=use_gpu,
  48. det_db_unclip_ratio=2.5,
  49. det_db_thresh=0.1,
  50. det_db_box_thresh=0.3,
  51. warmup=True)
  52. #
  53. # ocr = PaddleOCR(use_angle_cls=True,
  54. # rec_model_dir='./ch_ppocr_server_v2.0_rec_infer',
  55. # det_model_dir='./ch_ppocr_server_v2.0_det_infer',
  56. # cls_model_dir='./idcard_cls_infer',
  57. # ocr_version='PP-OCRv2',
  58. # rec_algorithm='CRNN',
  59. # use_gpu=use_gpu,
  60. # det_db_unclip_ratio=2.5,
  61. # det_db_thresh=0.1,
  62. # det_db_box_thresh=0.3,
  63. # warmup=True)
  64. # 初始化 角度检测器 对象
  65. ad = AngleDetector(ocr)
  66. # 初始化 身份证ocr识别 对象
  67. m = CetOcr(ocr, ad)
  68. # Get 健康检查
  69. @app.get("/ping")
  70. def ping():
  71. return "pong!"
  72. # 解析传入的 json对象
  73. class CetInfo(BaseModel):
  74. image: str
  75. # /ocr_system/bankcard 银行卡
  76. # /ocr_system/regbook 户口本
  77. # /ocr_system/schoolcert 学信网
  78. # Post 接口
  79. # 计算耗时
  80. # 异常处理
  81. @app.post("/ocr_system/cet")
  82. @sxtimeit
  83. @web_try()
  84. # 传入=> base64码 -> np
  85. # 返回=> 检测到到结果 -> (conf, angle, parser, image_type)
  86. def cet(request: Request, cer: CetInfo):
  87. image = base64_to_np(cer.image)
  88. return m.predict(image)
  89. if __name__ == '__main__':
  90. import uvicorn
  91. import argparse
  92. parser = argparse.ArgumentParser()
  93. parser.add_argument('--host', default='0.0.0.0')
  94. parser.add_argument('--port', default=8080)
  95. opt = parser.parse_args()
  96. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  97. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)