server.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import io
  2. import json
  3. import re
  4. from fastapi import FastAPI, Request, File, UploadFile, Body
  5. from fastapi.middleware.cors import CORSMiddleware
  6. from sx_utils.sximage import *
  7. from sx_utils.sxtime import sxtimeit
  8. from sx_utils.sxweb import web_try
  9. import requests
  10. from PIL import Image
  11. from pydantic import BaseModel
  12. import sys
  13. import logging
  14. import os
  15. import cv2
  16. from paddleocr import PaddleOCR
  17. logger = logging.getLogger('log')
  18. logger.setLevel(logging.DEBUG)
  19. # 调用模块时,如果错误引用,比如多次调用,每次会添加Handler,造成重复日志,这边每次都移除掉所有的handler,后面在重新添加,可以解决这类问题
  20. while logger.hasHandlers():
  21. for i in logger.handlers:
  22. logger.removeHandler(i)
  23. # file log 写入文件配置
  24. formatter = logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') # 日志的格式
  25. # 本地运行时,这部分需注释
  26. # fh = logging.FileHandler(r'/var/log/be.log', encoding='utf-8') # 日志文件路径文件名称,编码格式
  27. # fh.setLevel(logging.DEBUG) # 日志打印级别
  28. # fh.setFormatter(formatter)
  29. # logger.addHandler(fh)
  30. # console log 控制台输出控制
  31. ch = logging.StreamHandler(sys.stdout)
  32. ch.setLevel(logging.DEBUG)
  33. ch.setFormatter(formatter)
  34. logger.addHandler(ch)
  35. app = FastAPI()
  36. origins = ["*"]
  37. app.add_middleware(
  38. CORSMiddleware,
  39. allow_origins=origins,
  40. allow_credentials=True,
  41. allow_methods=["*"],
  42. allow_headers=["*"],
  43. )
  44. use_gpu = False
  45. if os.getenv('USE_CUDA') == 'gpu':
  46. use_gpu = True
  47. logger.info(f"->是否使用GPU:{use_gpu}")
  48. ocr = PaddleOCR(use_angle_cls=True, rec_model_dir="./table_rec_infer/",det_model_dir="./table_det_infer/",cls_model_dir="table_cls_infer",lang="ch")
  49. @app.get("/ping")
  50. def ping():
  51. return "pong!"
# Request body accepted by the POST /ocr_system/paddle endpoint.
class ImageListInfo(BaseModel):
    # Images to run OCR on; each element is handed to base64_to_np, so these are
    # presumably base64-encoded image strings — TODO confirm with callers.
    images: list
    # Tag copied verbatim into the "type" field of every returned text entry.
    img_type: str
  55. @app.post("/ocr_system/paddle")
  56. @sxtimeit
  57. @web_try()
  58. def paddle(request: Request,info: ImageListInfo):
  59. logger.info(f"->图片数量:{len(info.images)}")
  60. res_list = []
  61. for b_img in info.images:
  62. img = base64_to_np(b_img)
  63. result=ocr.ocr(img,cls=True)
  64. r_list = []
  65. for text_list in result:
  66. if len(text_list) >= 1:
  67. data = {}
  68. data["confidence"]= text_list[1][1]
  69. data["text"] = text_list[1][0]
  70. data["type"] = info.img_type
  71. data["text_region"]= text_list[0]
  72. r_list.append(data)
  73. res_list.append(r_list)
  74. return res_list
  75. if __name__ == '__main__':
  76. import uvicorn
  77. import argparse
  78. parser = argparse.ArgumentParser()
  79. parser.add_argument('--host', default='0.0.0.0')
  80. parser.add_argument('--port', default=8080)
  81. opt = parser.parse_args()
  82. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  83. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)