server.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # -*- coding: UTF-8 -*-
  2. import json
  3. from base64 import b64decode
  4. import cv2
  5. import numpy as np
  6. from fastapi import FastAPI, Request
  7. from fastapi.middleware.cors import CORSMiddleware
  8. from pydantic import BaseModel
  9. from paddleocr import PaddleOCR, PPStructure
  10. from sx_utils.sxweb import *
  11. from sx_utils.sximage import *
  12. import threading
  13. import os
  14. import re
  15. from sx_utils.sx_log import *
  16. import paddleclas
  17. from cores.post_hander import *
# NOTE(review): format_print is not defined in this file; presumably it comes
# from one of the wildcard sx_utils imports above and sets up print/log
# formatting — confirm.
format_print()

# Initialise the FastAPI application.
app = FastAPI()

# CORS: accept any origin, any HTTP method and any request header.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# A single shared table-structure engine, created once at import time.
# table_res() takes table_engine_lock around each call to table_engine.
table_engine_lock = threading.Lock()
table_engine = PPStructure(layout=False,
                           table=True,
                           use_gpu=True,
                           show_log=True,
                           use_angle_cls=True,
                           # Alternative detection / recognition models (disabled):
                           # det_model_dir="models/det/det_table_v2",
                           # det_model_dir="models/det/det_table_v3",
                           # rec_model_dir="models/rec/rec_table_v1",
                           table_model_dir="models/table/SLANet_911")

# A single shared image-orientation classifier, created once at import time.
# predict_cls() takes cls_lock around its call to cls_model.predict().
cls_lock = threading.Lock()
cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
  41. # # 普通表格
  42. # table_engine = PPStructure(layout=False,
  43. # table=True,
  44. # use_gpu=use_gpu,
  45. # show_log=True,
  46. # det_model_dir="models/det/det_table_v2",
  47. # rec_model_dir="./models/rec/rec_table_v1",
  48. # table_model_dir="models/table/SLANet_v2")
  49. #
  50. # # 长度较长表格
  51. # table_engine1 = PPStructure(layout=False,
  52. # table=True,
  53. # use_gpu=use_gpu,
  54. # show_log=True,
  55. # det_model_dir="models/det/det_table_v1",
  56. # rec_model_dir="./models/rec/rec_table_v1",
  57. # table_model_dir="./models/table/SLAnet_v1")
  58. #
  59. # # 针对某些特殊情况的补充模型
  60. # table_engine2 = PPStructure(layout=False,
  61. # table=True,
  62. # use_gpu=use_gpu,
  63. # show_log=True,
  64. # det_model_dir="models/det/det_table_v3",
  65. # rec_model_dir="./models/rec/rec_table_v1",
  66. # table_model_dir="./models/table/SLAnet_v1")
  67. #
  68. #
  69. #
  70. # 用于判断各个角度table的识别效果,识别的字段越多,效果越好
  71. def cal_html_to_chs(html):
  72. res = []
  73. rows = re.split('<tr>', html)
  74. for row in rows:
  75. row = re.split('<td>', row)
  76. cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''), row))
  77. rec_str = ''.join(cells)
  78. for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
  79. rec_str = rec_str.replace(tag, '')
  80. res.append(rec_str)
  81. rec_res = ''.join(res).replace(' ', '')
  82. rec_res = re.split('<tdcolspan="\w+">', rec_res)
  83. rec_res = ''.join(rec_res).replace(' ', '')
  84. print(rec_res)
  85. return len(rec_res)
  86. def predict_cls(image, conf=0.8):
  87. try:
  88. cls_lock.acquire()
  89. result = cls_model.predict(image)
  90. finally:
  91. cls_lock.release()
  92. for res in result:
  93. score = res[0]['scores'][0]
  94. label_name = res[0]['label_names'][0]
  95. print(f"score: {score}, label_name: {label_name}")
  96. if score > conf:
  97. return int(label_name)
  98. return -1
  99. def rotate_to_zero(image, current_degree):
  100. current_degree = current_degree // 90
  101. if current_degree == 0:
  102. return image
  103. to_rotate = (4 - current_degree) - 1
  104. image = cv2.rotate(image, to_rotate)
  105. return image
  106. def get_zero_degree_image(img):
  107. step = 0.2
  108. for idx, i in enumerate([-1, 0, 1, 2]):
  109. if i >= 0:
  110. image = cv2.rotate(img.copy(), i)
  111. else:
  112. image = img.copy()
  113. conf = 0.8 - (idx * step)
  114. current_degree = predict_cls(image, conf) # 0 90 180 270 -1 识别不出来
  115. if current_degree != -1:
  116. img = rotate_to_zero(image, current_degree)
  117. break
  118. else:
  119. continue
  120. return img
  121. def table_res(im, ROTATE=-1):
  122. im = im.copy()
  123. # 获取正向图片
  124. img = get_zero_degree_image(im)
  125. # cv2.imwrite('1.jpg', img)
  126. try:
  127. table_engine_lock.acquire()
  128. res = table_engine(img)
  129. finally:
  130. table_engine_lock.release()
  131. html = res[0]['res']['html']
  132. return res, html
# Request body model used by the POST /ocr_system/table handler.
class TableInfo(BaseModel):
    # Image payload as a string; the handler passes it to base64_to_np().
    image: str
    # Required string field; not read by the handler in this file.
    det: str
# Simple liveness check: GET /ping always returns the same fixed string.
@app.get("/ping")
def ping():
    return 'pong!!!!!!!!!'
  139. @app.post("/ocr_system/table")
  140. @web_try()
  141. def table(image: TableInfo):
  142. img = base64_to_np(image.image)
  143. res, html = table_res(img)
  144. if html:
  145. post_hander = PostHandler(html)
  146. return {'html': post_hander.format_predict_html}
  147. else:
  148. raise Exception('无法识别')
# Printed once at import time, after the models and routes above are set up.
print('table system init success!')