Browse Source

添加注释

liweiquan 7 months ago
parent
commit
aeb0dca2fb
5 changed files with 162 additions and 104 deletions
  1. 46 20
      cores/check_table.py
  2. 37 29
      cores/post_decorators.py
  3. 5 0
      cores/post_hander.py
  4. 72 55
      server.py
  5. 2 0
      sx_utils/sximage.py

+ 46 - 20
cores/check_table.py

@@ -4,27 +4,26 @@ import numpy as np
 
 class Table:
     def __init__(self, html, img=[]):
+        """
+        表格类的初始化函数。
+
+        Parameters:
+            html (str): 输入的HTML字符串。
+            img (List): 输入的图像数组,默认为空列表。
+        """
         self.img = img
         self.html = html
-        self.html_arr = []
-        self.total = 0
-        self.empty = 0
-
-    # def get_body(self):
-    #     try:
-    #         res = self.html.split('<tbody>')[1]
-    #     except Exception as r:
-    #         print('<tbody> 识别失败')
-    #         print(r)
-    #     try:
-    #         res = res.split('</tbody>')[0]
-    #     except Exception as r:
-    #         print('</tbody> 识别失败')
-    #         print(r)
-    #     return res
+        self.html_arr = []  # 存储HTML解析后的表格内容
+        self.total = 0  # 表格单元总数
+        self.empty = 0  # 空白表格单元数
 
     def get_tr(self):
-        # str = self.get_body()
+        """
+        从HTML中提取并返回表格行。
+
+        Returns:
+            List: 提取的表格行列表。
+        """
         str = self.html
         if len(str.split('<tr>')) > 1:
             return str.split('<tr>')[1:]
@@ -32,6 +31,12 @@ class Table:
             return []
 
     def get_td(self):
+        """
+        从HTML中提取并存储表格单元。
+
+        Returns:
+            None
+        """
         if self.html_arr != []:
             return
         tr_list = self.get_tr()
@@ -51,6 +56,12 @@ class Table:
             self.html_arr.append(temp_list)
 
     def get_empty(self):
+        """
+        统计表格中的空白单元格数量和总单元格数量。
+
+        Returns:
+            None
+        """
         self.get_td()
         if self.total != 0:
             return
@@ -61,6 +72,12 @@ class Table:
                     self.empty += 1
 
     def change_green2white(self):
+        """
+        将图像中绿色区域修改为白色。
+
+        Returns:
+            None
+        """
         hsv = cv2.cvtColor(self.img, cv2.COLOR_BGR2HSV)
         lower_green = np.array([35, 43, 46])
         upper_green = np.array([77, 220, 255])
@@ -69,6 +86,12 @@ class Table:
         self.img[mask_green != 0] = color
 
     def get_str(self):
+        """
+        从HTML数组中获取字符串。
+
+        Returns:
+            str: 提取的字符串。
+        """
         str = ''
         for tr in self.html_arr:
             for cell in tr:
@@ -76,12 +99,15 @@ class Table:
         return str
 
     def check_html(self):
+        """
+        检查HTML表格的质量,如果识别效果不佳,则修改图像颜色。
+
+        Returns:
+            int: 返回1表示识别效果不佳,返回0表示识别效果良好。
+        """
         self.get_empty()
         html_str = self.get_str()
 
-        print(self.html)
-        print(self.html_arr)
-        print(self.empty)
         if (self.empty > 4 and self.empty > self.total // 4) or (
                 '项目' in html_str and '每份' in html_str and '营养素参考值' in html_str and np.max(
                 [len(a) for a in self.html_arr]) < 3):

+ 37 - 29
cores/post_decorators.py

@@ -4,9 +4,10 @@ import re
 
 @decorator
 def rule1_decorator(f, *args, **kwargs):
-    '''
-    predict_line = ['项目 ', '', '每100克营养素参考值%', '']
-    '''
+    """
+        处理表头第二格合并至第三格的情况
+        predict_line = ['项目 ', '', '每100克营养素参考值%', '']
+    """
     predict_line = args[1]
     predict_line = f(*args, **kwargs)
     idx = 0
@@ -26,9 +27,10 @@ def rule1_decorator(f, *args, **kwargs):
 
 @decorator
 def rule2_decorator(f, *args, **kwargs):
-    '''
-    predict_line = ['碳水化合物18.2克', '', '6%', '']
-    '''
+    """
+        处理碳水化合物这一行,第二格合并至第一格的问题
+        predict_line = ['碳水化合物18.2克', '', '6%', '']
+    """
     predict_line = args[1]
     predict_line = f(*args, **kwargs)
     idx = 0
@@ -48,13 +50,14 @@ def rule2_decorator(f, *args, **kwargs):
 
 @decorator
 def rule3_decorator(f, *args, **kwargs):
-    '''
-    ['患直质', '1.6克', '3%', '']
-    ['脂扇', '1.1', '19%', '']
-    ['碳水化合物', '勿18.2克', '6%', '']
-    ['能量.', '408千焦',	'5%']
-    ['——精', '2.9克']
-    '''
+    """
+        处理易错字
+        ['患直质', '1.6克', '3%', '']
+        ['脂扇', '1.1', '19%', '']
+        ['碳水化合物', '勿18.2克', '6%', '']
+        ['能量.', '408千焦',	'5%']
+        ['——精', '2.9克']
+    """
     predict_line = args[1]
     predict_line = f(*args, **kwargs)
     predict_line = [re.sub('患直质', '蛋白质', s) for s in predict_line]
@@ -67,9 +70,10 @@ def rule3_decorator(f, *args, **kwargs):
 
 @decorator
 def rule4_decorator(f, *args, **kwargs):
-    '''
-    ['', '项目每一百克', '营养素参考值']
-    '''
+    """
+        处理表头第一格合并至第二格的问题
+        ['', '项目每100克', '营养素参考值']
+    """
     predict_line = args[1]
     predict_line = f(*args, **kwargs)
     try:
@@ -83,9 +87,10 @@ def rule4_decorator(f, *args, **kwargs):
 
 @decorator
 def rule5_decorator(f, *args, **kwargs):
-    '''
+    """
+        处理表头第三格合并至第二格的问题
         predict_line = ['项目 ', '每份(70g)营养素参考值%', '']
-    '''
+    """
     predict_line = args[1]
     predict_line = f(*args, **kwargs)
     try:
@@ -102,9 +107,10 @@ def rule5_decorator(f, *args, **kwargs):
 
 @decorator
 def rule6_decorator(f, *args, **kwargs):
-    '''
-    predict_line = ['项目 ', '', '每份(70g)营养素参考值%', '']
-    '''
+    """
+        处理表头第二格合并至第三格的问题
+        predict_line = ['项目 ', '', '每份(70g)营养素参考值%', '']
+    """
     predict_line = args[1]
     predict_line = f(*args, **kwargs)
     idx = 0
@@ -123,22 +129,24 @@ def rule6_decorator(f, *args, **kwargs):
 
 @decorator
 def rule7_decorator(f, *args, **kwargs):
-    '''
-    predict_line = ['项目 ', '', '每份(70g)营养素参考值%', '']
-    '''
+    """
+        处理项目缺一个字未识别出的问题
+        predict_line = ['项', '每份(70g)', '营养素参考值%', '']
+    """
     predict_line = f(*args, **kwargs)
     try:
         if '项目' in predict_line[0] or '项' in predict_line[0] or '目' in predict_line[0]:
             predict_line[0] = '项目'
     except IndexError as e:
-        print('rule6_decorator', e)
+        print('rule7_decorator', e)
     return predict_line
 
 @decorator
 def rule8_decorator(f, *args, **kwargs):
-    '''
-    predict_line = ['项目 ', '', '每份(70g)营养素参考值%', '']
-    '''
+    """
+        处理表头数据集中在第三格的问题
+        predict_line = ['', '', '项目每份(70g)营养素参考值%', '']
+    """
     predict_line = f(*args, **kwargs)
     try:
         if len(predict_line) >= 3 \
@@ -151,7 +159,7 @@ def rule8_decorator(f, *args, **kwargs):
             predict_line[1] = '每100克'
             predict_line[2] = '营养素参考值%'
     except IndexError as e:
-        print('rule6_decorator', e)
+        print('rule8_decorator', e)
     return predict_line
 
 

+ 5 - 0
cores/post_hander.py

@@ -7,6 +7,7 @@ class PostHandler:
         self.predict_html = predict_html
         self.format_lines = self._get_format_lines()
 
+    # 将二维列表处理为想要的富文本格式
     @property
     def format_predict_html(self):
         if self.format_lines:
@@ -40,10 +41,12 @@ class PostHandler:
         else:
             return self.predict_html
 
+    # 对每一行进行处理
     @combined_decorator
     def _format_predict_line(self, predict_line):
         return predict_line
 
+    # 对每一行进行处理
     def _get_format_lines(self):
         format_lines = []
         predict_lines = self._get_lines(self.predict_html)
@@ -53,6 +56,7 @@ class PostHandler:
             format_lines.append(line)
         return format_lines
 
+    # 获取每一行
     def _get_lines(self, html) -> List[str]:
         '''
         res:  ['<td>项目</td><td>每100克</td><td>营养素参考值%</td>',...]
@@ -65,6 +69,7 @@ class PostHandler:
                 res.extend(m)
         return res
 
+    # 切分每一个格子
     def _split_to_words(self, line):
         '''
         line: '<td>项目</td><td>每100克</td><td>营养素参考值%</td>'

+ 72 - 55
server.py

@@ -1,19 +1,11 @@
 # -*- coding: UTF-8 -*-
-import json
-from base64 import b64decode
-import base64
-
-import cv2
-import numpy as np
-from fastapi import FastAPI, Request
+from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from paddleocr import PaddleOCR, PPStructure
+from paddleocr import PPStructure
 from sx_utils.sxweb import *
 from sx_utils.sximage import *
 import threading
-import os
-import re
 from sx_utils.sx_log import *
 import paddleclas
 
@@ -35,15 +27,12 @@ app.add_middleware(
 )
 
 table_engine_lock = threading.Lock()
-
+# 表格识别模型
 table_engine = PPStructure(layout=False,
                            table=True,
                            use_gpu=True,
                            show_log=True,
                            use_angle_cls=True,
-                           #    det_model_dir="models/det/det_table_v2",
-                           #    det_model_dir="models/det/det_table_v3",
-                           #    rec_model_dir="models/rec/rec_table_v1",
                            table_model_dir="models/table/SLANet_911")
 
 cls_lock = threading.Lock()
@@ -51,38 +40,17 @@ cls_lock = threading.Lock()
 cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
 
 
-# # 普通表格
-# table_engine = PPStructure(layout=False,
-#                            table=True,
-#                            use_gpu=use_gpu,
-#                            show_log=True,
-#                            det_model_dir="models/det/det_table_v2",
-#                            rec_model_dir="./models/rec/rec_table_v1",
-#                            table_model_dir="models/table/SLANet_v2")
-#
-# # 长度较长表格
-# table_engine1 = PPStructure(layout=False,
-#                             table=True,
-#                             use_gpu=use_gpu,
-#                             show_log=True,
-#                             det_model_dir="models/det/det_table_v1",
-#                             rec_model_dir="./models/rec/rec_table_v1",
-#                             table_model_dir="./models/table/SLAnet_v1")
-#
-# # 针对某些特殊情况的补充模型
-# table_engine2 = PPStructure(layout=False,
-#                             table=True,
-#                             use_gpu=use_gpu,
-#                             show_log=True,
-#                             det_model_dir="models/det/det_table_v3",
-#                             rec_model_dir="./models/rec/rec_table_v1",
-#                             table_model_dir="./models/table/SLAnet_v1")
-#
-#
-#
-
 # 用于判断各个角度table的识别效果,识别的字段越多,效果越好
 def cal_html_to_chs(html):
+    """
+    将HTML中的表格数据提取并合并为中文字符串。
+
+    Parameters:
+        html (str): 输入的HTML字符串。
+
+    Returns:
+        int: 合并后的中文字符串长度。
+    """
     res = []
     rows = re.split('<tr>', html)
     for row in rows:
@@ -97,11 +65,20 @@ def cal_html_to_chs(html):
     rec_res = ''.join(res).replace(' ', '')
     rec_res = re.split('<tdcolspan="\w+">', rec_res)
     rec_res = ''.join(rec_res).replace(' ', '')
-    print(rec_res)
     return len(rec_res)
 
 
 def predict_cls(image, conf=0.8):
+    """
+    使用分类模型对图像进行预测,并返回预测结果。
+
+    Parameters:
+        image (np.ndarray): 输入的图像数组。
+        conf (float): 置信度阈值,默认为0.8。
+
+    Returns:
+        int: 预测结果的类别标签。
+    """
     try:
         cls_lock.acquire()
         result = cls_model.predict(image)
@@ -111,24 +88,40 @@ def predict_cls(image, conf=0.8):
         score = res[0]['scores'][0]
         label_name = res[0]['label_names'][0]
         print(f"score: {score}, label_name: {label_name}")
-        # print(conf)
         if score > conf:
             return int(label_name)
     return -1
 
 
 def rotate_to_zero(image, current_degree):
-    # cv2.imwrite('1.jpg', image)
+    """
+    将图像旋转至零度方向。
+
+    Parameters:
+        image (np.ndarray): 输入的图像数组。
+        current_degree (float): 当前的旋转角度。
+
+    Returns:
+        np.ndarray: 旋转后的图像数组。
+    """
     current_degree = current_degree // 90
     if current_degree == 0:
         return image
     to_rotate = (4 - current_degree) - 1
     image = cv2.rotate(image, to_rotate)
-    # cv2.imwrite('2.jpg', image)
     return image
 
 
 def get_zero_degree_image(img):
+    """
+    获取经零度方向旋转后的图像。
+
+    Parameters:
+        img (np.ndarray): 输入的图像数组。
+
+    Returns:
+        np.ndarray: 经零度方向旋转后的图像数组。
+    """
     step = 0.2
     for idx, i in enumerate([-1, 0, 1, 2]):
         if i >= 0:
@@ -146,11 +139,18 @@ def get_zero_degree_image(img):
 
 
 def table_res(im, ROTATE=-1):
+    """
+    获取图像中表格的识别结果和HTML字符串。
+
+    Parameters:
+        im (np.ndarray): 输入的图像数组。
+        ROTATE (int): 旋转角度,默认为-1。
+
+    Returns:
+        Tuple: 表格识别结果和HTML字符串。
+    """
     im = im.copy()
-    # cv2.imwrite('before-rotate.jpg', im)
-    # 获取正向图片
     img = get_zero_degree_image(im)
-    # cv2.imwrite('after-rotate.jpg', img)
     try:
         table_engine_lock.acquire()
         res = table_engine(img)
@@ -167,23 +167,40 @@ class TableInfo(BaseModel):
 
 @app.get("/ping")
 def ping():
+    """
+    用于检查服务是否存活的端点。
+
+    Returns:
+        str: 返回pong表示服务存活。
+    """
     return 'pong!!!!!!!!!'
 
 
 @app.post("/ocr_system/table")
 @web_try()
 def table(image: TableInfo):
+    """
+    对图像中的表格进行识别并返回HTML字符串。
+
+    Parameters:
+        image (TableInfo): 输入的图像信息。
+
+    Returns:
+        dict: 包含HTML字符串的字典。
+    """
+    # 转换图片格式
     img = base64_to_np(image.image)
+    # 进行表格识别
     res, html = table_res(img)
-    # print(html)
+    # 创建Table实例
     table = Table(html, img)
+    # 效果不好则重新识别
     if table.check_html():
         res, html = table_res(table.img)
 
     if html:
-        post_hander = PostHandler(html)
-        # print(post_hander.format_predict_html)
-        return {'html': post_hander.format_predict_html}
+        post_handler = PostHandler(html)
+        return {'html': post_handler.format_predict_html}
     else:
         raise Exception('无法识别')
 

+ 2 - 0
sx_utils/sximage.py

@@ -1,7 +1,9 @@
+import base64
 from base64 import b64decode
 import numpy as np
 import cv2
 
+# base64格式转numpy格式
 def base64_to_np(img_data):
     color_image_flag = 1
     img_data = img_data.split(',',1)[-1]