Prechádzať zdrojové kódy

效果差重新识别逻辑修改

liweiquan 6 mesiacov pred
rodič
commit
ef2cfba8ba
2 zmenil súbory, kde vykonal 22 pridanie a 17 odobranie
  1. 7 10
      cores/check_table.py
  2. 15 7
      server.py

+ 7 - 10
cores/check_table.py

@@ -1,8 +1,6 @@
 import cv2
 import numpy as np
 
-hard_colors = [[[35, 43, 46], [77, 220, 255]]]
-
 class Table:
     def __init__(self, html, img=[]):
         """
@@ -72,7 +70,7 @@ class Table:
                 if cell == '':
                     self.empty += 1
 
-    def change_hard2white(self):
+    def change_hard2white(self, hard_color):
         """
         将图像中绿色区域修改为白色。
 
@@ -81,11 +79,10 @@ class Table:
         """
         color = [248, 248, 255]
         hsv = cv2.cvtColor(self.img, cv2.COLOR_BGR2HSV)
-        for hard_color in hard_colors:
-            lower_green = np.array(hard_color[0])
-            upper_green = np.array(hard_color[1])
-            mask_green = cv2.inRange(hsv, lower_green, upper_green)
-            self.img[mask_green != 0] = color
+        lower_green = np.array(hard_color[0])
+        upper_green = np.array(hard_color[1])
+        mask_green = cv2.inRange(hsv, lower_green, upper_green)
+        self.img[mask_green != 0] = color
 
     def get_str(self):
         """
@@ -100,7 +97,7 @@ class Table:
                 str += cell
         return str
 
-    def check_html(self):
+    def check_html(self, hard_color):
         """
         检查HTML表格的质量,如果识别效果不佳,则修改图像颜色。
 
@@ -113,6 +110,6 @@ class Table:
         # HTML字符串 html_str 中同时包含 '项目'、'每份' 和 '营养素参考值',并且在每一行的格子数中最大值小于3时。
         if (self.empty > 4 and self.empty > self.total // 4) or ('项目' in html_str and '每份' in html_str and '营养素参考值' in html_str and np.max([len(a) for a in self.html_arr]) < 3):
             print('识别效果不佳,改变图片颜色!')
-            self.change_hard2white()
+            self.change_hard2white(hard_color)
             return 1
         return 0

+ 15 - 7
server.py

@@ -29,11 +29,11 @@ app.add_middleware(
 table_engine_lock = threading.Lock()
 # 表格识别模型
 table_engine = CustomPPStructure(layout=False,
-                           table=True,
-                           use_gpu=True,
-                           show_log=True,
-                           use_angle_cls=True,
-                           table_model_dir="models/table/SLANet_latest")
+                                 table=True,
+                                 use_gpu=True,
+                                 show_log=True,
+                                 use_angle_cls=True,
+                                 table_model_dir="models/table/SLANet_latest")
 
 cls_lock = threading.Lock()
 
@@ -165,7 +165,7 @@ class TableInfo(BaseModel):
     image: str
     det: str
     prefer_cell: bool = Field(
-        default=False, 
+        default=False,
         description="是否使用cell_boxes替代dt_boxes"
     )
 
@@ -181,6 +181,9 @@ def ping():
     return 'pong!!!!!!!!!'
 
 
+hard_colors = [[[35, 43, 46], [77, 220, 255]]]
+
+
 @app.post("/ocr_system/table")
 @web_try()
 def table(info: TableInfo):
@@ -200,8 +203,13 @@ def table(info: TableInfo):
     # 创建Table实例
     table = Table(html, img)
     # 效果不好则重新识别
-    if table.check_html():
+    next_index = 0
+    while table.check_html(hard_colors[next_index]):
         res, html = table_res(table.img, prefer_cell=info.prefer_cell)
+        table = Table(html, img)
+        next_index += 1
+        if next_index >= len(hard_colors):
+            break
 
     if html:
         post_handler = PostHandler(html)