server.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # -*- coding: UTF-8 -*-
  2. from fastapi import FastAPI
  3. from fastapi.middleware.cors import CORSMiddleware
  4. from pydantic import BaseModel
  5. from paddleocr import PPStructure
  6. from sx_utils.sxweb import *
  7. from sx_utils.sximage import *
  8. import threading
  9. from sx_utils.sx_log import *
  10. import paddleclas
  11. from cores.post_hander import *
  12. from cores.check_table import *
  13. format_print()
  14. # 初始化APP
  15. app = FastAPI()
  16. origins = ["*"]
  17. app.add_middleware(
  18. CORSMiddleware,
  19. allow_origins=origins,
  20. allow_credentials=True,
  21. allow_methods=["*"],
  22. allow_headers=["*"],
  23. )
  24. table_engine_lock = threading.Lock()
  25. # 表格识别模型
  26. table_engine = PPStructure(layout=False,
  27. table=True,
  28. use_gpu=True,
  29. show_log=True,
  30. use_angle_cls=True,
  31. det_model_dir="models/det/det_pse_quad",
  32. table_model_dir="models/table/SLANet_911")
  33. cls_lock = threading.Lock()
  34. cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
  35. # 用于判断各个角度table的识别效果,识别的字段越多,效果越好
  36. def cal_html_to_chs(html):
  37. """
  38. 将HTML中的表格数据提取并合并为中文字符串。
  39. Parameters:
  40. html (str): 输入的HTML字符串。
  41. Returns:
  42. int: 合并后的中文字符串长度。
  43. """
  44. res = []
  45. rows = re.split('<tr>', html)
  46. for row in rows:
  47. row = re.split('<td>', row)
  48. cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''), row))
  49. rec_str = ''.join(cells)
  50. for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
  51. rec_str = rec_str.replace(tag, '')
  52. res.append(rec_str)
  53. rec_res = ''.join(res).replace(' ', '')
  54. rec_res = re.split('<tdcolspan="\w+">', rec_res)
  55. rec_res = ''.join(rec_res).replace(' ', '')
  56. return len(rec_res)
  57. def predict_cls(image, conf=0.8):
  58. """
  59. 使用分类模型对图像进行预测,并返回预测结果。
  60. Parameters:
  61. image (np.ndarray): 输入的图像数组。
  62. conf (float): 置信度阈值,默认为0.8。
  63. Returns:
  64. int: 预测结果的类别标签。
  65. """
  66. try:
  67. cls_lock.acquire()
  68. result = cls_model.predict(image)
  69. finally:
  70. cls_lock.release()
  71. for res in result:
  72. score = res[0]['scores'][0]
  73. label_name = res[0]['label_names'][0]
  74. print(f"score: {score}, label_name: {label_name}")
  75. if score > conf:
  76. return int(label_name)
  77. return -1
  78. def rotate_to_zero(image, current_degree):
  79. """
  80. 将图像旋转至零度方向。
  81. Parameters:
  82. image (np.ndarray): 输入的图像数组。
  83. current_degree (float): 当前的旋转角度。
  84. Returns:
  85. np.ndarray: 旋转后的图像数组。
  86. """
  87. current_degree = current_degree // 90
  88. if current_degree == 0:
  89. return image
  90. to_rotate = (4 - current_degree) - 1
  91. image = cv2.rotate(image, to_rotate)
  92. return image
  93. def get_zero_degree_image(img):
  94. """
  95. 获取经零度方向旋转后的图像。
  96. Parameters:
  97. img (np.ndarray): 输入的图像数组。
  98. Returns:
  99. np.ndarray: 经零度方向旋转后的图像数组。
  100. """
  101. step = 0.2
  102. for idx, i in enumerate([-1, 0, 1, 2]):
  103. if i >= 0:
  104. image = cv2.rotate(img.copy(), i)
  105. else:
  106. image = img.copy()
  107. conf = 0.8 - (idx * step)
  108. current_degree = predict_cls(image, conf) # 0 90 180 270 -1 识别不出来
  109. if current_degree != -1:
  110. img = rotate_to_zero(image, current_degree)
  111. break
  112. else:
  113. continue
  114. return img
  115. def table_res(im, ROTATE=-1):
  116. """
  117. 获取图像中表格的识别结果和HTML字符串。
  118. Parameters:
  119. im (np.ndarray): 输入的图像数组。
  120. ROTATE (int): 旋转角度,默认为-1。
  121. Returns:
  122. Tuple: 表格识别结果和HTML字符串。
  123. """
  124. im = im.copy()
  125. img = get_zero_degree_image(im)
  126. try:
  127. table_engine_lock.acquire()
  128. res = table_engine(img)
  129. finally:
  130. table_engine_lock.release()
  131. html = res[0]['res']['html']
  132. return res, html
  133. class TableInfo(BaseModel):
  134. image: str
  135. det: str
  136. @app.get("/ping")
  137. def ping():
  138. """
  139. 用于检查服务是否存活的端点。
  140. Returns:
  141. str: 返回pong表示服务存活。
  142. """
  143. return 'pong!!!!!!!!!'
  144. @app.post("/ocr_system/table")
  145. @web_try()
  146. def table(image: TableInfo):
  147. """
  148. 对图像中的表格进行识别并返回HTML字符串。
  149. Parameters:
  150. image (TableInfo): 输入的图像信息。
  151. Returns:
  152. dict: 包含HTML字符串的字典。
  153. """
  154. # 转换图片格式
  155. img = base64_to_np(image.image)
  156. # 进行表格识别
  157. res, html = table_res(img)
  158. # 创建Table实例
  159. table = Table(html, img)
  160. # 效果不好则重新识别
  161. if table.check_html():
  162. res, html = table_res(table.img)
  163. if html:
  164. post_handler = PostHandler(html)
  165. return {'html': post_handler.format_predict_html}
  166. else:
  167. raise Exception('无法识别')
  168. print('table system init success!')