Pārlūkot izejas kodu

Merge branch 'new' of http://gogsb.soaringnova.com/chenguilong/ocr-table into new

李时威 11 mēneši atpakaļ
vecāks
revīzija
09a09c2356
3 mainītis faili ar 98 papildinājumiem un 2 dzēšanām
  1. 80 0
      cores/check_table.py
  2. 6 1
      cores/post_decorators.py
  3. 12 1
      server.py

+ 80 - 0
cores/check_table.py

@@ -0,0 +1,80 @@
+import cv2
+import numpy as np
+
+
+class Table:
+    def __init__(self, html, img=[]):
+        self.img = img
+        self.html = html
+        self.html_arr = []
+        self.total = 0
+        self.empty = 0
+
+    def get_body(self):
+        try:
+            res = self.html.split('<tbody>')[1]
+        except Exception as r:
+            print('<tbody> 识别失败')
+            print(r)
+        try:
+            res = res.split('</tbody>')[0]
+        except Exception as r:
+            print('</tbody> 识别失败')
+            print(r)
+        return res
+
+    def get_tr(self):
+        str = self.get_body()
+        if len(str.split('<tr>')) > 1:
+            return str.split('<tr>')
+        else:
+            return []
+
+    def get_td(self):
+        if self.html_arr != []:
+            return
+        tr_list = self.get_tr()
+        for i in range(len(tr_list)):
+            if tr_list[i] == '':
+                continue
+            tr = tr_list[i].split('</td>')[:-1]
+            temp_list = []
+            for cell in tr:
+                if '<td colspan=\\"3\\">' in cell:
+                    temp_list.append(cell.split('<td colspan=\\"3\\">')[1])
+                if '<td>' in cell:
+                    temp_list.append(cell.split('<td>')[1])
+            self.html_arr.append(temp_list)
+
+    def get_empty(self):
+        self.get_td()
+        if self.total != 0:
+            return
+        for tr in self.html_arr:
+            for cell in tr:
+                self.total += 1
+                if cell == '':
+                    self.empty += 1
+
+    def change_green2white(self):
+        hsv = cv2.cvtColor(self.img, cv2.COLOR_BGR2HSV)
+        lower_green = np.array([35, 43, 46])
+        upper_green = np.array([77, 220, 255])
+        mask_green = cv2.inRange(hsv, lower_green, upper_green)
+        color = [248, 248, 255]
+        self.img[mask_green != 0] = color
+
+    def get_str(self):
+        str = ''
+        for tr in self.html_arr:
+            for cell in tr:
+                str+=cell
+        return str
+
+    def check_html(self):
+        self.get_empty()
+        html_str = self.get_str()
+        if (self.empty > 4 and self.empty > self.total // 4) or ('项目' in html_str and '每份' in html_str and '营养素参考值' in html_str and np.max([len(a) for a in self.html_arr])<3):
+            self.change_green2white()
+            return 1
+        return 0

+ 6 - 1
cores/post_decorators.py

@@ -50,7 +50,8 @@ def rule3_decorator(f, *args, **kwargs):
     ['患直质', '1.6克', '3%', '']
     ['脂扇', '1.1', '19%', '']
     ['碳水化合物', '勿18.2克', '6%', '']
-
+    ['能量.', '408千焦',	'5%']
+    ['——精', '2.9克']
     '''
     predict_line = args[1]
     predict_line = f(*args, **kwargs)
@@ -58,6 +59,7 @@ def rule3_decorator(f, *args, **kwargs):
     predict_line = [re.sub('脂扇', '脂肪', s) for s in predict_line]
     predict_line = [re.sub('勿(.*克)', '\\1', s) for s in predict_line]
     predict_line = [re.sub('毫 克', '毫克', s) for s in predict_line]
+    predict_line = [re.sub('——精', '——糖', s) for s in predict_line]
     return predict_line
 
 
@@ -79,6 +81,9 @@ def rule4_decorator(f, *args, **kwargs):
 
 @decorator
 def rule5_decorator(f, *args, **kwargs):
+    '''
+        predict_line = ['项目 ', '每份(70g)营养素参考值%', '']
+    '''
     predict_line = args[1]
     predict_line = f(*args, **kwargs)
     try:

+ 12 - 1
server.py

@@ -1,6 +1,7 @@
 # -*- coding: UTF-8 -*-
 import json
 from base64 import b64decode
+import base64
 
 import cv2
 import numpy as np
@@ -17,6 +18,7 @@ from sx_utils.sx_log import *
 import paddleclas
 
 from cores.post_hander import *
+from cores.check_table import *
 
 
 format_print()
@@ -111,17 +113,20 @@ def predict_cls(image, conf=0.8):
         score = res[0]['scores'][0]
         label_name = res[0]['label_names'][0]
         print(f"score: {score}, label_name: {label_name}")
+        # print(conf)
         if score > conf:
             return int(label_name)
     return -1
 
 
 def rotate_to_zero(image, current_degree):
+    # cv2.imwrite('1.jpg', image)
     current_degree = current_degree // 90
     if current_degree == 0:
         return image
     to_rotate = (4 - current_degree) - 1
     image = cv2.rotate(image, to_rotate)
+    # cv2.imwrite('2.jpg', image)
     return image
 
 
@@ -143,9 +148,10 @@ def get_zero_degree_image(img):
 
 def table_res(im, ROTATE=-1):
     im = im.copy()
+    # cv2.imwrite('before-rotate.jpg', im)
     # 获取正向图片
     img = get_zero_degree_image(im)
-    # cv2.imwrite('1.jpg', img)
+    # cv2.imwrite('after-rotate.jpg', img)
     try:
         table_engine_lock.acquire()
         res = table_engine(img)
@@ -169,6 +175,11 @@ def ping():
 def table(image: TableInfo):
     img = base64_to_np(image.image)
     res, html = table_res(img)
+    # print(html)
+    table = Table(html,img)
+    if table.check_html():
+        res, html = table_res(table.img)
+
 
     if html:
         post_hander = PostHandler(html)