# server.py (2.5 KB)
  1. import os
  2. from typing import Optional
  3. from fastapi import FastAPI, Request
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from paddleocr import PaddleOCR
  6. from pydantic import BaseModel
  7. from core.ocr import BankOcr
  8. from core.direction import AngleDetector
  9. from s_utils.image import *
  10. from s_utils.time import timeit
  11. from s_utils.web import web_try
  12. app = FastAPI()
  13. origins = ["*"]
  14. app.add_middleware(
  15. CORSMiddleware,
  16. allow_origins=origins,
  17. allow_credentials=True,
  18. allow_methods=["*"],
  19. allow_headers=["*"],
  20. )
  21. use_gpu = os.getenv('USE_CUDA') == 'gpu'
  22. print(f'use gpu: {use_gpu}')
  23. ocr = PaddleOCR(
  24. det_model_dir="./models/models_0/det",
  25. rec_model_dir="./models/models_0/rec",
  26. cls_model_dir="./models/models_0/cls",
  27. rec_char_dict_path="./ppocr_keys_bank.txt",
  28. det_db_unclip_ratio=2.5,
  29. ocr_version='PP-OCRv2',
  30. det_db_thresh=0.3,
  31. det_db_box_thresh=0.6,
  32. use_angle_cls=True,
  33. use_gpu=use_gpu,
  34. use_space_char=False,
  35. warmup=True
  36. )
  37. origin_ocr = PaddleOCR(
  38. det_model_dir="./models/models_1/det",
  39. rec_model_dir="./models/models_0/rec",
  40. cls_model_dir="./models/models_1/cls",
  41. rec_char_dict_path="./ppocr_keys_bank.txt",
  42. det_db_unclip_ratio=2.5,
  43. ocr_version='PP-OCRv2',
  44. det_db_thresh=0.3,
  45. det_db_box_thresh=0.6,
  46. use_angle_cls=True,
  47. use_gpu=use_gpu,
  48. use_space_char=False,
  49. warmup=True
  50. )
  51. ad = AngleDetector(origin_ocr)
  52. m = BankOcr(ocr, ad)
  53. @app.get("/ping")
  54. def ping():
  55. return "pong!"
  56. class IdCardInfo(BaseModel):
  57. image: str
  58. fn: Optional[str]
  59. @app.post("/ocr_system/bankcard")
  60. @timeit
  61. @web_try()
  62. def detect(request: Request, id_card: IdCardInfo):
  63. image = base64_to_np(id_card.image)
  64. return m.predict(image)
  65. if __name__ == '__main__':
  66. import uvicorn
  67. import argparse
  68. parser = argparse.ArgumentParser()
  69. parser.add_argument('--host', default='0.0.0.0')
  70. parser.add_argument('--port', default=8080)
  71. opt = parser.parse_args()
  72. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  73. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True, workers=1)