server.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. from fastapi import FastAPI, Request
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from pydantic import BaseModel
  4. from paddleocr import PaddleOCR
  5. from core.direction import AngleDetector
  6. from sx_utils.sximage import *
  7. from sx_utils.sxtime import sxtimeit
  8. from sx_utils.sxweb import web_try
  9. from core.ocr import CetOcr
  10. import os
  11. # 导入一些包
  12. app = FastAPI()
  13. origins = ["*"]
  14. # CORS 跨源资源共享
  15. app.add_middleware(
  16. CORSMiddleware,
  17. allow_origins=origins,
  18. allow_credentials=True,
  19. allow_methods=["*"],
  20. allow_headers=["*"],
  21. )
  22. # templates = Jinja2Templates(directory='templates')
  23. use_gpu = False
  24. if os.getenv('USE_CUDA') == 'gpu':
  25. use_gpu = True
  26. print(f'use gpu: {use_gpu}')
  27. # 初始化ocr模型和后处理模型
  28. # 分类
  29. # ocr = PaddleOCR(use_angle_cls=True,
  30. # # 方向
  31. # rec_model_dir="./idcard_rec_infer/",
  32. # det_model_dir="./idcard_det_infer/",
  33. # cls_model_dir="idcard_cls_infer",
  34. # # 识别
  35. # rec_algorithm='CRNN',
  36. # ocr_version='PP-OCRv2',
  37. # # 中文字典
  38. # rec_char_dict_path="./ppocr_keys_v1.txt", lang="ch",
  39. # use_gpu=use_gpu,
  40. # # 预训练-->效果不明显
  41. # # 网络不够大、不够深
  42. # # 数据集普遍较小,batch size普遍较小
  43. # warmup=True)
  44. # ocr = PaddleOCR(use_angle_cls=True,
  45. # use_gpu=use_gpu)
  46. # ocr = PaddleOCR(use_angle_cls=True,
  47. # use_gpu=use_gpu,
  48. # det_db_unclip_ratio=2.5,
  49. # det_db_thresh=0.1,
  50. # det_db_box_thresh=0.3,
  51. # warmup=True)
  52. #
  53. ocr = PaddleOCR(use_angle_cls=True,
  54. rec_model_dir='./server_model/ch_ppocr_server_v2.0_rec_infer/',
  55. det_model_dir='./server_model/ch_ppocr_server_v2.0_det_infer/',
  56. ocr_version='PP-OCRv2',
  57. rec_algorithm='CRNN',
  58. use_gpu=use_gpu,
  59. det_db_unclip_ratio=2.5,
  60. det_db_thresh=0.1,
  61. det_db_box_thresh=0.3,
  62. warmup=True)
  63. # 初始化 角度检测器 对象
  64. ad = AngleDetector(ocr)
  65. # 初始化 身份证ocr识别 对象
  66. m = CetOcr(ocr, ad)
  67. # Get 健康检查
  68. @app.get("/ping")
  69. def ping():
  70. return "pong!"
  71. # 解析传入的 json对象
  72. class CetInfo(BaseModel):
  73. image: str
  74. # /ocr_system/bankcard 银行卡
  75. # /ocr_system/regbook 户口本
  76. # /ocr_system/schoolcert 学信网
  77. # Post 接口
  78. # 计算耗时
  79. # 异常处理
  80. @app.post("/ocr_system/cet")
  81. @sxtimeit
  82. @web_try()
  83. # 传入=> base64码 -> np
  84. # 返回=> 检测到到结果 -> (conf, angle, parser, image_type)
  85. def cet(request: Request, cer: CetInfo):
  86. image = base64_to_np(cer.image)
  87. return m.predict(image)
  88. if __name__ == '__main__':
  89. import uvicorn
  90. import argparse
  91. parser = argparse.ArgumentParser()
  92. parser.add_argument('--host', default='0.0.0.0')
  93. parser.add_argument('--port', default=8080)
  94. opt = parser.parse_args()
  95. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  96. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)