Ver Fonte

fix: 加入方向判断

Zhang Li há 1 ano atrás
pai
commit
3cf46d5612
3 ficheiros alterados com 58 adições e 17 exclusões
  1. 1 0
      environment.yml
  2. 51 13
      server.py
  3. 6 4
      testing/table_test.py

+ 1 - 0
environment.yml

@@ -22,6 +22,7 @@ dependencies:
       - cpca
       - cpca
       - uvicorn
       - uvicorn
       - protobuf==3.20.1
       - protobuf==3.20.1
+      - paddleclas==2.5.1
       - -i https://mirror.baidu.com/pypi/simple
       - -i https://mirror.baidu.com/pypi/simple
       - paddlepaddle  # gpu==2.3.0.post110
       - paddlepaddle  # gpu==2.3.0.post110
       - -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
       - -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html

+ 51 - 13
server.py

@@ -13,6 +13,8 @@ import threading
 import os
 import os
 import re
 import re
 from sx_utils.sx_log import *
 from sx_utils.sx_log import *
+import paddleclas
+
 
 
 
 
 
 
@@ -41,6 +43,10 @@ table_engine = PPStructure(layout=False,
                            # rec_model_dir="models/rec/rec_table_v1",
                            # rec_model_dir="models/rec/rec_table_v1",
                            table_model_dir="models/table/SLANet_905")
                            table_model_dir="models/table/SLANet_905")
 
 
+
+cls_lock = threading.Lock()
+
+cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
 # # 普通表格
 # # 普通表格
 # table_engine = PPStructure(layout=False,
 # table_engine = PPStructure(layout=False,
 #                            table=True,
 #                            table=True,
@@ -91,13 +97,54 @@ def cal_html_to_chs(html):
     return len(rec_res)
     return len(rec_res)
 
 
 
 
+
+def predict_cls(image, conf=0.8):
+    try:
+        cls_lock.acquire()
+        result = cls_model.predict(image)
+    finally:
+        cls_lock.release()
+    for res in result:
+        score = res[0]['scores'][0]
+        label_name = res[0]['label_names'][0]
+        print(f'score: {score}, label_name: {label_name}')
+        if score > conf:
+            return int(label_name)
+    return -1
+
+
+def rotate_to_zero(image, current_degree):
+    current_degree = current_degree // 90
+    if current_degree == 0:
+        return image
+    to_rotate = (4 - current_degree) - 1
+    image = cv2.rotate(image, to_rotate)
+    return image
+
+
+def get_zero_degree_image(img):
+    step = 0.2
+    for idx, i in enumerate([-1, 0, 1, 2]):
+        if i >= 0:
+            image = cv2.rotate(img.copy(), i)
+        else:
+            image = img.copy()
+        conf = 0.8 - (idx * step)
+        current_degree = predict_cls(image, conf)  # 0 90  180 270 -1 识别不出来
+        if current_degree != -1:
+            img = rotate_to_zero(image, current_degree)
+            break
+        else:
+            continue
+    return img
+
 def table_res(im, ROTATE=-1):
 def table_res(im, ROTATE=-1):
     im = im.copy()
     im = im.copy()
-    if ROTATE >= 0:
-        im = cv2.rotate(im, ROTATE)
+    # 获取正向图片
+    img = get_zero_degree_image(im)
     try:
     try:
         table_engine_lock.acquire()
         table_engine_lock.acquire()
-        res = table_engine(im)
+        res = table_engine(img)
     finally:
     finally:
         table_engine_lock.release()
         table_engine_lock.release()
     html = res[0]['res']['html']
     html = res[0]['res']['html']
@@ -117,16 +164,7 @@ def ping():
 @web_try()
 @web_try()
 def table(image: TableInfo):
 def table(image: TableInfo):
     img = base64_to_np(image.image)
     img = base64_to_np(image.image)
-    res_len = 0
-    res = None
-    for i in [-1, 0, 1, 2]:
-        _res, html = table_res(img, i)
-        print(html)
-        _res_len = cal_html_to_chs(html)
-        if _res_len > res_len:
-            res = _res
-            res_len = _res_len
-
+    res, html = table_res(img)
     if res:
     if res:
         return res[0]['res']
         return res[0]['res']
     else:
     else:

+ 6 - 4
testing/table_test.py

@@ -23,13 +23,15 @@ class TestTableOcr(unittest.TestCase):
     #     self.assertEqual('"pong!"', pong.text, 'Not work')
     #     self.assertEqual('"pong!"', pong.text, 'Not work')
 
 
     def test_table_01(self):
     def test_table_01(self):
-        print("被旋转的图片")
+        print("270")
         fn = Path(__file__).parent / '../images/01.jpeg'
         fn = Path(__file__).parent / '../images/01.jpeg'
         res = send_request(fn)
         res = send_request(fn)
         # self.assertEqual('000', res['status'], 'Not work')
         # self.assertEqual('000', res['status'], 'Not work')
-    #
-    # def test_table_02(self):
-    #     res = send_request('./images/02.jpg')
+
+    def test_table_02(self):
+        print("270")
+        fn = Path(__file__).parent / '../images/02.jpg'
+        res = send_request(fn)
     #     self.assertEqual('000', res['status'], 'Not work')
     #     self.assertEqual('000', res['status'], 'Not work')
     #
     #
     # def test_table_03(self):
     # def test_table_03(self):