server.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from fastapi import FastAPI, Request
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from pydantic import BaseModel
  4. from paddleocr import PaddleOCR
  5. from core.direction import AngleDetector
  6. from sx_utils.sximage import *
  7. from sx_utils.sxtime import sxtimeit
  8. from sx_utils.sxweb import web_try
  9. from core.ocr import IdCardOcr
  10. import os
  11. # 导入一些包
  12. app = FastAPI()
  13. origins = ["*"]
  14. # CORS 跨源资源共享
  15. app.add_middleware(
  16. CORSMiddleware,
  17. allow_origins=origins,
  18. allow_credentials=True,
  19. allow_methods=["*"],
  20. allow_headers=["*"],
  21. )
  22. # templates = Jinja2Templates(directory='templates')
  23. use_gpu = False
  24. if os.getenv('USE_CUDA') == 'gpu':
  25. use_gpu = True
  26. print(f'use gpu: {use_gpu}')
  27. # 初始化ocr模型和后处理模型
  28. # 分类
  29. # ocr = PaddleOCR(use_angle_cls=True,
  30. # # 方向
  31. # rec_model_dir="./idcard_rec_infer/",
  32. # det_model_dir="./idcard_det_infer/",
  33. # cls_model_dir="idcard_cls_infer",
  34. # # 识别
  35. # rec_algorithm='CRNN',
  36. # ocr_version='PP-OCRv2',
  37. # # 中文字典
  38. # rec_char_dict_path="./ppocr_keys_v1.txt", lang="ch",
  39. # use_gpu=use_gpu,
  40. # # 预训练-->效果不明显
  41. # # 网络不够大、不够深
  42. # # 数据集普遍较小,batch size普遍较小
  43. # warmup=True)
  44. ocr = PaddleOCR(use_angle_cls=True,
  45. use_gpu=use_gpu)
  46. # 初始化 角度检测器 对象
  47. ad = AngleDetector(ocr)
  48. # 初始化 身份证ocr识别 对象
  49. m = IdCardOcr(ocr, ad)
  50. # Get 健康检查
  51. @app.get("/ping")
  52. def ping():
  53. return "pong!"
  54. # 解析传入的 json对象
  55. class IdCardInfo(BaseModel):
  56. image: str
  57. image_type: str
  58. # /ocr_system/bankcard 银行卡
  59. # /ocr_system/regbook 户口本
  60. # /ocr_system/schoolcert 学信网
  61. # Post 接口
  62. # 计算耗时
  63. # 异常处理
  64. @app.post("/ocr_system/idcard")
  65. @sxtimeit
  66. @web_try()
  67. # 传入=> base64码 -> np
  68. # 返回=> 检测到到结果 -> (conf, angle, parser, image_type)
  69. def idcard(request: Request, id_card: IdCardInfo):
  70. image = base64_to_np(id_card.image)
  71. return m.predict(image, id_card.image_type)
  72. if __name__ == '__main__':
  73. import uvicorn
  74. import argparse
  75. parser = argparse.ArgumentParser()
  76. parser.add_argument('--host', default='0.0.0.0')
  77. parser.add_argument('--port', default=8080)
  78. opt = parser.parse_args()
  79. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  80. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)