server.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # -*- coding: UTF-8 -*-
  2. from fastapi import FastAPI
  3. from fastapi.middleware.cors import CORSMiddleware
  4. from pydantic import BaseModel, Field
  5. from cores.table_engine import CustomPPStructure
  6. from sx_utils.sxweb import *
  7. from sx_utils.sximage import *
  8. import threading
  9. from sx_utils.sx_log import *
  10. import paddleclas
  11. from cores.post_hander import *
  12. from cores.check_table import *
  13. format_print()
  14. # 初始化APP
  15. app = FastAPI()
  16. origins = ["*"]
  17. app.add_middleware(
  18. CORSMiddleware,
  19. allow_origins=origins,
  20. allow_credentials=True,
  21. allow_methods=["*"],
  22. allow_headers=["*"],
  23. )
  24. table_engine_lock = threading.Lock()
  25. # 表格识别模型
  26. table_engine = CustomPPStructure(layout=False,
  27. table=True,
  28. use_gpu=True,
  29. show_log=True,
  30. use_angle_cls=True,
  31. table_model_dir="models/table/SLANet_911")
  32. cls_lock = threading.Lock()
  33. cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
  34. # 用于判断各个角度table的识别效果,识别的字段越多,效果越好
  35. def cal_html_to_chs(html):
  36. """
  37. 将HTML中的表格数据提取并合并为中文字符串。
  38. Parameters:
  39. html (str): 输入的HTML字符串。
  40. Returns:
  41. int: 合并后的中文字符串长度。
  42. """
  43. res = []
  44. rows = re.split('<tr>', html)
  45. for row in rows:
  46. row = re.split('<td>', row)
  47. cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''), row))
  48. rec_str = ''.join(cells)
  49. for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
  50. rec_str = rec_str.replace(tag, '')
  51. res.append(rec_str)
  52. rec_res = ''.join(res).replace(' ', '')
  53. rec_res = re.split('<tdcolspan="\w+">', rec_res)
  54. rec_res = ''.join(rec_res).replace(' ', '')
  55. return len(rec_res)
  56. def predict_cls(image, conf=0.8):
  57. """
  58. 使用分类模型对图像进行预测,并返回预测结果。
  59. Parameters:
  60. image (np.ndarray): 输入的图像数组。
  61. conf (float): 置信度阈值,默认为0.8。
  62. Returns:
  63. int: 预测结果的类别标签。
  64. """
  65. try:
  66. cls_lock.acquire()
  67. result = cls_model.predict(image)
  68. finally:
  69. cls_lock.release()
  70. for res in result:
  71. score = res[0]['scores'][0]
  72. label_name = res[0]['label_names'][0]
  73. print(f"score: {score}, label_name: {label_name}")
  74. if score > conf:
  75. return int(label_name)
  76. return -1
  77. def rotate_to_zero(image, current_degree):
  78. """
  79. 将图像旋转至零度方向。
  80. Parameters:
  81. image (np.ndarray): 输入的图像数组。
  82. current_degree (float): 当前的旋转角度。
  83. Returns:
  84. np.ndarray: 旋转后的图像数组。
  85. """
  86. current_degree = current_degree // 90
  87. if current_degree == 0:
  88. return image
  89. to_rotate = (4 - current_degree) - 1
  90. image = cv2.rotate(image, to_rotate)
  91. return image
  92. def get_zero_degree_image(img):
  93. """
  94. 获取经零度方向旋转后的图像。
  95. Parameters:
  96. img (np.ndarray): 输入的图像数组。
  97. Returns:
  98. np.ndarray: 经零度方向旋转后的图像数组。
  99. """
  100. step = 0.2
  101. for idx, i in enumerate([-1, 0, 1, 2]):
  102. if i >= 0:
  103. image = cv2.rotate(img.copy(), i)
  104. else:
  105. image = img.copy()
  106. conf = 0.8 - (idx * step)
  107. current_degree = predict_cls(image, conf) # 0 90 180 270 -1 识别不出来
  108. if current_degree != -1:
  109. img = rotate_to_zero(image, current_degree)
  110. break
  111. else:
  112. continue
  113. return img
  114. def table_res(im, ROTATE=-1, prefer_cell=False):
  115. """
  116. 获取图像中表格的识别结果和HTML字符串。
  117. Parameters:
  118. im (np.ndarray): 输入的图像数组。
  119. ROTATE (int): 旋转角度,默认为-1。
  120. prefer_cell (bool): 是否使用cell_boxes替代dt_boxes,默认为否。
  121. Returns:
  122. Tuple: 表格识别结果和HTML字符串。
  123. """
  124. im = im.copy()
  125. img = get_zero_degree_image(im)
  126. try:
  127. table_engine_lock.acquire()
  128. res = table_engine(img, prefer_table_cell_boxes=prefer_cell)
  129. finally:
  130. table_engine_lock.release()
  131. html = res[0]['res']['html']
  132. return res, html
  133. class TableInfo(BaseModel):
  134. image: str
  135. det: str
  136. prefer_cell: bool = Field(
  137. default=False,
  138. description="是否使用cell_boxes替代dt_boxes"
  139. )
  140. @app.get("/ping")
  141. def ping():
  142. """
  143. 用于检查服务是否存活的端点。
  144. Returns:
  145. str: 返回pong表示服务存活。
  146. """
  147. return 'pong!!!!!!!!!'
  148. @app.post("/ocr_system/table")
  149. @web_try()
  150. def table(info: TableInfo):
  151. """
  152. 对图像中的表格进行识别并返回HTML字符串。
  153. Parameters:
  154. info (TableInfo): 输入的图像信息。
  155. Returns:
  156. dict: 包含HTML字符串的字典。
  157. """
  158. # 转换图片格式
  159. img = base64_to_np(info.image)
  160. # 进行表格识别
  161. res, html = table_res(img, prefer_cell=info.prefer_cell)
  162. # 创建Table实例
  163. table = Table(html, img)
  164. # 效果不好则重新识别
  165. if table.check_html():
  166. res, html = table_res(table.img, prefer_cell=info.prefer_cell)
  167. if html:
  168. post_handler = PostHandler(html)
  169. return {'html': post_handler.format_predict_html}
  170. else:
  171. raise Exception('无法识别')
  172. print('table system init success!')