# server.py - FastAPI service exposing a PaddleOCR text-recognition endpoint.

import io
import json
import logging
import os
import random
import re
import sys

import cv2
import numpy as np
import requests
from fastapi import FastAPI, Request, File, UploadFile, Body
from fastapi.middleware.cors import CORSMiddleware
from paddleocr import PaddleOCR
from PIL import Image
from pydantic import BaseModel

# NOTE(review): wildcard import; presumably the source of helpers such as
# base64_to_np used further down - confirm and import them by name.
from sx_utils.sximage import *
from sx_utils.sxtime import sxtimeit
from sx_utils.sxweb import web_try
  17. logger = logging.getLogger('log')
  18. logger.setLevel(logging.DEBUG)
  19. # 调用模块时,如果错误引用,比如多次调用,每次会添加Handler,造成重复日志,这边每次都移除掉所有的handler,后面在重新添加,可以解决这类问题
  20. while logger.hasHandlers():
  21. for i in logger.handlers:
  22. logger.removeHandler(i)
  23. # file log 写入文件配置
  24. formatter = logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') # 日志的格式
  25. # 本地运行时,这部分需注释
  26. # fh = logging.FileHandler(r'/var/log/be.log', encoding='utf-8') # 日志文件路径文件名称,编码格式
  27. # fh.setLevel(logging.DEBUG) # 日志打印级别
  28. # fh.setFormatter(formatter)
  29. # logger.addHandler(fh)
  30. # console log 控制台输出控制
  31. ch = logging.StreamHandler(sys.stdout)
  32. ch.setLevel(logging.DEBUG)
  33. ch.setFormatter(formatter)
  34. logger.addHandler(ch)
  35. app = FastAPI()
  36. origins = ["*"]
  37. app.add_middleware(
  38. CORSMiddleware,
  39. allow_origins=origins,
  40. allow_credentials=True,
  41. allow_methods=["*"],
  42. allow_headers=["*"],
  43. )
  44. use_gpu = False
  45. if os.getenv('USE_CUDA') == 'gpu':
  46. use_gpu = True
  47. logger.info(f"->是否使用GPU:{use_gpu}")
  48. ocr = PaddleOCR(use_angle_cls=True, rec_model_dir="./table_rec_infer/",det_model_dir="./table_det_infer/",cls_model_dir="table_cls_infer",lang="ch")
  49. @app.get("/ping")
  50. def ping():
  51. return "pong!"
  52. class ImageListInfo(BaseModel):
  53. images: list
  54. img_type: str
  55. @app.post("/ocr_system/paddle")
  56. @sxtimeit
  57. @web_try()
  58. def rotate_bound_white_bg(self, image, angle):
  59. (h, w) = image.shape[:2]
  60. (cX, cY) = (w // 2, h // 2)
  61. M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0)
  62. cos = np.abs(M[0, 0])
  63. sin = np.abs(M[0, 1])
  64. nW = int((h * sin) + (w * cos))
  65. nH = int((h * cos) + (w * sin))
  66. M[0, 2] += (nW / 2) - cX
  67. M[1, 2] += (nH / 2) - cY
  68. return cv2.warpAffine(image, M, (nW, nH), borderValue=(255, 255, 255))
  69. class GetImageRotation(object):
  70. def __init__(self):
  71. self.ocr = PaddleOCR(use_angle_cls=True)
  72. self.ocr_angle = PaddleOCR(use_angle_cls=True)
  73. def get_real_rotation_when_null_rect(self, rect_list):
  74. w_div_h_sum = 0
  75. count = 0
  76. for rect in rect_list:
  77. p0 = rect[0]
  78. p1 = rect[1]
  79. p2 = rect[2]
  80. p3 = rect[3]
  81. width = abs(p1[0] - p0[0])
  82. height = abs(p3[1] - p0[1])
  83. w_div_h = width / height
  84. if abs(w_div_h - 1.0) < 0.5:
  85. count += 1
  86. continue
  87. w_div_h_sum += w_div_h
  88. length = len(rect_list) - count
  89. if length == 0:
  90. length = 1
  91. if w_div_h_sum / length >= 1.5:
  92. return 1
  93. else:
  94. return 0
  95. def get_real_rotation_flag(self, rect_list):
  96. ret_rect = []
  97. w_div_h_mean = 0
  98. real_rect_count = 0
  99. rect_big_list = []
  100. rect_small_list = []
  101. w_div_h_sum_big = []
  102. w_div_h_sum_small = []
  103. for rect in rect_list:
  104. p0 = rect[0]
  105. p1 = rect[1]
  106. p2 = rect[2]
  107. p3 = rect[3]
  108. width = abs(p1[0] - p0[0])
  109. height = abs(p3[1] - p0[1])
  110. w_div_h = width / height
  111. if 5 <= w_div_h <= 25:
  112. real_rect_count +=1
  113. rect_big_list.append(rect)
  114. w_div_h_sum_big.append(w_div_h)
  115. if 0.04 <= w_div_h <= 0.2:
  116. real_rect_count -=1
  117. rect_small_list.append(rect)
  118. w_div_h_sum_small.append(w_div_h)
  119. if real_rect_count > 0:
  120. ret_rect = rect_big_list
  121. w_div_h_mean = np.mean(w_div_h_sum_big)
  122. else:
  123. ret_rect = rect_small_list
  124. w_div_h_mean = np.mean(w_div_h_sum_small)
  125. if w_div_h_mean >= 1.5:
  126. return 1, ret_rect
  127. else:
  128. return 0, ret_rect
  129. def crop_image(self, rect, image):
  130. p0 = rect[0]
  131. p1 = rect[1]
  132. p2 = rect[2]
  133. p3 = rect[3]
  134. crop = image[int(p0[1]):int(p2[1]), int(p0[0]):int(p2[0])]
  135. # crop_image = Image.fromarray(crop)
  136. return crop
  137. def get_img_real_angle(self, img):
  138. ret_angle = 0
  139. image = img
  140. # ocr = PaddleOCR(use_angle_cls=True)
  141. # angle_cls = ocr.ocr(img_path, det=False, rec=False, cls=True)
  142. rect_list = self.ocr.ocr(image, rec=False)
  143. if rect_list != [[]]:
  144. except_flag = False
  145. try:
  146. real_angle_flag, rect_good = self.get_real_rotation_flag(
  147. rect_list)
  148. rect_crop = choice(rect_good)
  149. # rect_crop = rect_good[0]
  150. image_crop = self.crop_image(rect_crop, image)
  151. # ocr_angle = PaddleOCR(use_angle_cls=True)
  152. angle_cls = self.ocr_angle.ocr(
  153. image_crop, det=False, rec=False, cls=True)
  154. except:
  155. except_flag = True
  156. real_angle_flag = self.get_real_rotation_when_null_rect(
  157. rect_list)
  158. # ocr_angle = PaddleOCR(use_angle_cls=True)
  159. angle_cls = self.ocr_angle.ocr(
  160. image, det=False, rec=False, cls=True)
  161. else:
  162. return 0
  163. if angle_cls[0][0] == '0':
  164. if real_angle_flag:
  165. ret_angle = 0
  166. else:
  167. ret_angle = 270
  168. if not except_flag:
  169. anticlockwise_90 = rotate_bound_white_bg(image_crop, 90)
  170. angle_cls = self.ocr_angle.ocr(anticlockwise_90, det=False, rec=False, cls=True)
  171. if angle_cls[0][0] == '0':
  172. ret_angle = 270
  173. if angle_cls[0][0] == '180':
  174. ret_angle = 90
  175. if angle_cls[0][0] == '180':
  176. if real_angle_flag:
  177. ret_angle = 180
  178. else:
  179. ret_angle = 90
  180. return ret_angle
  181. def paddle(request: Request,info: ImageListInfo):
  182. logger.info(f"->图片数量:{len(info.images)}")
  183. res_list = []
  184. for b_img in info.images:
  185. img = base64_to_np(b_img)
  186. route=GetImageRotation()
  187. route2=route.get_img_real_angle(img)
  188. if route2==90 or route2== 270:
  189. img=im.transpose(img.ROTATE_90)
  190. result=ocr.ocr(img,cls=True)
  191. r_list = []
  192. for text_list in result:
  193. if len(text_list) >= 1:
  194. data = {}
  195. data["confidence"]= text_list[1][1]
  196. data["text"] = text_list[1][0]
  197. data["type"] = info.img_type
  198. data["text_region"]= text_list[0]
  199. r_list.append(data)
  200. res_list.append(r_list)
  201. return res_list
  202. if __name__ == '__main__':
  203. import uvicorn
  204. import argparse
  205. parser = argparse.ArgumentParser()
  206. parser.add_argument('--host', default='0.0.0.0')
  207. parser.add_argument('--port', default=8080)
  208. opt = parser.parse_args()
  209. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  210. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)