# server.py
  1. import json
  2. from base64 import b64decode
  3. import cv2
  4. import numpy as np
  5. from fastapi import FastAPI, Request
  6. from fastapi.middleware.cors import CORSMiddleware
  7. from pydantic import BaseModel
  8. from paddleocr import PaddleOCR, PPStructure
  9. from sx_utils.sxweb import *
  10. from sx_utils.sximage import *
  11. import os
  12. # 初始化app
  13. app = FastAPI()
  14. origins = ["*"]
  15. app.add_middleware(
  16. CORSMiddleware,
  17. allow_origins=origins,
  18. allow_credentials=True,
  19. allow_methods=["*"],
  20. allow_headers=["*"],
  21. )
  22. use_gpu = os.getenv('USE_CUDA') == 'gpu'
  23. print(f'use gpu: {use_gpu}')
  24. # 普通表格
  25. table_engine = PPStructure(layout=False,
  26. table=True,
  27. use_gpu=use_gpu,
  28. show_log=True,
  29. det_model_dir="models/det/det_table_v2",
  30. rec_model_dir="./models/rec/rec_table_v1",
  31. table_model_dir="models/table/SLAnet_v1")
  32. # 长度较长表格
  33. table_engine1 = PPStructure(layout=False,
  34. table=True,
  35. use_gpu=use_gpu,
  36. show_log=True,
  37. det_model_dir="models/det/det_table_v1",
  38. rec_model_dir="./models/rec/rec_table_v1",
  39. table_model_dir="./models/table/SLAnet_v1")
  40. # 针对某些特殊情况的补充模型
  41. table_engine2 = PPStructure(layout=False,
  42. table=True,
  43. use_gpu=use_gpu,
  44. show_log=True,
  45. det_model_dir="models/det/det_table_v3",
  46. rec_model_dir="./models/rec/rec_table_v1",
  47. table_model_dir="./models/table/SLAnet_v1")
  48. class TableInfo(BaseModel):
  49. image: str
  50. det: str
  51. @app.get("/ping")
  52. def ping():
  53. return 'pong!'
  54. @app.post("/ocr_system/table")
  55. @web_try()
  56. def table(image: TableInfo):
  57. img = base64_to_np(image.image)
  58. if image.det == 'no':
  59. res = table_engine(img)
  60. elif image.det == 'yes':
  61. res = table_engine1(img)
  62. elif image.det == 'spe':
  63. res = table_engine2(img)
  64. return res[0]['res']
  65. if __name__ == '__main__':
  66. import uvicorn
  67. import argparse
  68. parser = argparse.ArgumentParser()
  69. parser.add_argument('--host', default='0.0.0.0')
  70. parser.add_argument('--port', default=8080)
  71. opt = parser.parse_args()
  72. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  73. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)