server.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # -*- coding: UTF-8 -*-
  2. import json
  3. from base64 import b64decode
  4. import base64
  5. import cv2
  6. import numpy as np
  7. from fastapi import FastAPI, Request
  8. from fastapi.middleware.cors import CORSMiddleware
  9. from pydantic import BaseModel
  10. from paddleocr import PaddleOCR, PPStructure
  11. from sx_utils.sxweb import *
  12. from sx_utils.sximage import *
  13. import threading
  14. import os
  15. import re
  16. from sx_utils.sx_log import *
  17. import paddleclas
  18. from cores.post_hander import *
  19. from cores.check_table import *
  20. format_print()
  21. # 初始化APP
  22. app = FastAPI()
  23. origins = ["*"]
  24. app.add_middleware(
  25. CORSMiddleware,
  26. allow_origins=origins,
  27. allow_credentials=True,
  28. allow_methods=["*"],
  29. allow_headers=["*"],
  30. )
  31. table_engine_lock = threading.Lock()
  32. table_engine = PPStructure(layout=False,
  33. table=True,
  34. use_gpu=True,
  35. show_log=True,
  36. use_angle_cls=True,
  37. # det_model_dir="models/det/det_table_v2",
  38. # det_model_dir="models/det/det_table_v3",
  39. # rec_model_dir="models/rec/rec_table_v1",
  40. table_model_dir="models/table/SLANet_911")
  41. cls_lock = threading.Lock()
  42. cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
  43. # # 普通表格
  44. # table_engine = PPStructure(layout=False,
  45. # table=True,
  46. # use_gpu=use_gpu,
  47. # show_log=True,
  48. # det_model_dir="models/det/det_table_v2",
  49. # rec_model_dir="./models/rec/rec_table_v1",
  50. # table_model_dir="models/table/SLANet_v2")
  51. #
  52. # # 长度较长表格
  53. # table_engine1 = PPStructure(layout=False,
  54. # table=True,
  55. # use_gpu=use_gpu,
  56. # show_log=True,
  57. # det_model_dir="models/det/det_table_v1",
  58. # rec_model_dir="./models/rec/rec_table_v1",
  59. # table_model_dir="./models/table/SLAnet_v1")
  60. #
  61. # # 针对某些特殊情况的补充模型
  62. # table_engine2 = PPStructure(layout=False,
  63. # table=True,
  64. # use_gpu=use_gpu,
  65. # show_log=True,
  66. # det_model_dir="models/det/det_table_v3",
  67. # rec_model_dir="./models/rec/rec_table_v1",
  68. # table_model_dir="./models/table/SLAnet_v1")
  69. #
  70. #
  71. #
  72. # 用于判断各个角度table的识别效果,识别的字段越多,效果越好
  73. def cal_html_to_chs(html):
  74. res = []
  75. rows = re.split('<tr>', html)
  76. for row in rows:
  77. row = re.split('<td>', row)
  78. cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''), row))
  79. rec_str = ''.join(cells)
  80. for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
  81. rec_str = rec_str.replace(tag, '')
  82. res.append(rec_str)
  83. rec_res = ''.join(res).replace(' ', '')
  84. rec_res = re.split('<tdcolspan="\w+">', rec_res)
  85. rec_res = ''.join(rec_res).replace(' ', '')
  86. print(rec_res)
  87. return len(rec_res)
  88. def predict_cls(image, conf=0.8):
  89. try:
  90. cls_lock.acquire()
  91. result = cls_model.predict(image)
  92. finally:
  93. cls_lock.release()
  94. for res in result:
  95. score = res[0]['scores'][0]
  96. label_name = res[0]['label_names'][0]
  97. print(f"score: {score}, label_name: {label_name}")
  98. # print(conf)
  99. if score > conf:
  100. return int(label_name)
  101. return -1
  102. def rotate_to_zero(image, current_degree):
  103. # cv2.imwrite('1.jpg', image)
  104. current_degree = current_degree // 90
  105. if current_degree == 0:
  106. return image
  107. to_rotate = (4 - current_degree) - 1
  108. image = cv2.rotate(image, to_rotate)
  109. # cv2.imwrite('2.jpg', image)
  110. return image
  111. def get_zero_degree_image(img):
  112. step = 0.2
  113. for idx, i in enumerate([-1, 0, 1, 2]):
  114. if i >= 0:
  115. image = cv2.rotate(img.copy(), i)
  116. else:
  117. image = img.copy()
  118. conf = 0.8 - (idx * step)
  119. current_degree = predict_cls(image, conf) # 0 90 180 270 -1 识别不出来
  120. if current_degree != -1:
  121. img = rotate_to_zero(image, current_degree)
  122. break
  123. else:
  124. continue
  125. return img
  126. def table_res(im, ROTATE=-1):
  127. im = im.copy()
  128. # cv2.imwrite('before-rotate.jpg', im)
  129. # 获取正向图片
  130. img = get_zero_degree_image(im)
  131. # cv2.imwrite('after-rotate.jpg', img)
  132. try:
  133. table_engine_lock.acquire()
  134. res = table_engine(img)
  135. finally:
  136. table_engine_lock.release()
  137. html = res[0]['res']['html']
  138. return res, html
# Request body accepted by the POST /ocr_system/table endpoint.
class TableInfo(BaseModel):
    # Image payload; the handler passes it to base64_to_np().
    image: str
    # Required string field; not read by the handler visible in this file.
    det: str
# Liveness probe: GET /ping answers with a fixed string.
@app.get("/ping")
def ping():
    return 'pong!!!!!!!!!'
  145. @app.post("/ocr_system/table")
  146. @web_try()
  147. def table(image: TableInfo):
  148. img = base64_to_np(image.image)
  149. res, html = table_res(img)
  150. # print(html)
  151. table = Table(html, img)
  152. if table.check_html():
  153. print('-=-=-=')
  154. res, html = table_res(table.img)
  155. if html:
  156. post_hander = PostHandler(html)
  157. # print(post_hander.format_predict_html)
  158. return {'html': post_hander.format_predict_html}
  159. else:
  160. raise Exception('无法识别')
# Reached only after the module-level setup above has run.
print('table system init success!')