Browse Source

add Notes and fix lineparser

zeke-chin 2 years ago
parent
commit
7842203a3a
5 changed files with 70 additions and 59 deletions
  1. 28 42
      core/direction.py
  2. 6 3
      core/line_parser.py
  3. 23 5
      core/ocr.py
  4. 2 9
      core/parser.py
  5. 11 0
      run.py

+ 28 - 42
core/direction.py

@@ -15,12 +15,10 @@ class Directoin(Enum):
     RIGHT = 1
     BOTTOM = 2
     LEFT = 3
-    TOPRIGHT = 4
-    BOTTOMRIGHT = 5
-    BOTTOMLEFT = 6
-    TOPLEFT = 7
 
 
+
+# 父类
 class OcrAnchor(object):
     # 输入识别anchor的名字, 如身份证号
     def __init__(self, name: str, d: Directoin):
@@ -59,6 +57,7 @@ class OcrAnchor(object):
             Directoin.RIGHT: r_func,
         }
 
+    # 获取中心区域坐标 -> (x, y)
     def get_rec_area(self, res) -> Tuple[float, float]:
         """获得整张身份证的识别区域, 返回识别区域的中心点"""
         boxes = []
@@ -73,11 +72,15 @@ class OcrAnchor(object):
         # w, h = (r - l, b - t)
         return (l + r) / 2, (t + b) / 2
 
+    # 判断是否是 锚点
     def is_anchor(self, txt, box) -> bool:
         pass
 
+    # 找 锚点 -> 锚点坐标
     def find_anchor(self, res) -> Tuple[bool, float, float]:
-        """寻找身份证号的识别区域以及中心点,根据身份证的w > h判断是否水平"""
+        """
+        寻找身份证号的识别区域以及中心点
+        """
         for row in res:
             for r in row:
                 txt = r.txt.replace('-', '').replace(' ', '')
@@ -88,6 +91,8 @@ class OcrAnchor(object):
                     return True, (l + r) / 2, (t + b) / 2
         return False, 0., 0.
 
+    # 定位 锚点 -> 角度
+    # -> 锚点(x, y)  pic(x, y) is_horizontal
     def locate_anchor(self, res, is_horizontal) -> int:
         found, id_cx, id_cy = self.find_anchor(res)
 
@@ -96,6 +101,7 @@ class OcrAnchor(object):
         cx, cy = self.get_rec_area(res)
         # print(f'id_cx: {id_cx}, id_cy: {id_cy}')
         # print(f'cx: {cx}, cy: {cy}')
+        # 用k->get->func ==> f()
         f = self.direction_funcs.get(self.direction, None)
 
         return f((id_cx, id_cy), (cx, cy), is_horizontal)
@@ -108,6 +114,7 @@ class OcrAnchor(object):
         #     return 1 if id_cx < cx else 3
 
 
+# 子类1 人像面
 class FrontSideAnchor(OcrAnchor):
     def __init__(self, name: str, d: Directoin):
         super(FrontSideAnchor, self).__init__(name, d)
@@ -122,6 +129,7 @@ class FrontSideAnchor(OcrAnchor):
         return super(FrontSideAnchor, self).locate_anchor(res, is_horizontal)
 
 
+# 子类2 国徽面
 class BackSideAnchor(OcrAnchor):
     def __init__(self, name: str, d: Directoin):
         super(BackSideAnchor, self).__init__(name, d)
@@ -137,6 +145,10 @@ class BackSideAnchor(OcrAnchor):
         return super(BackSideAnchor, self).locate_anchor(res, is_horizontal)
 
 
+# 调用以上 🔧工具
+# <- ocr_生数据
+# == ocr_熟数据(行处理后)
+# -> 角度0/1/2/3
 def detect_angle(result, ocr_anchor: OcrAnchor):
     lp = LineParser(result)
     res = lp.parse()
@@ -148,13 +160,22 @@ def detect_angle(result, ocr_anchor: OcrAnchor):
 
 
 @dataclass
-# 角度检测器
 class AngleDetector(object):
+    """
+    角度检测器
+    """
     ocr: PaddleOCR
 
+    # 角度检测器
+    # <- img(cv2格式)  img_type
+    # == result <- img(cv2)
+    # -> angle       result(ocr生)
     def detect_angle(self, img, image_type):
         image_type = int(image_type)
-        ocr_anchor = BackSideAnchor('有效期', Directoin.BOTTOM) if image_type != 0 else FrontSideAnchor('身份证号', Directoin.BOTTOM)
+
+        # 初始化anchor对象
+        ocr_anchor = BackSideAnchor('有效期', Directoin.BOTTOM) if image_type != 0 else FrontSideAnchor('身份证号',
+                                                                                                             Directoin.BOTTOM)
         result = self.ocr.ocr(img, cls=True)
 
         try:
@@ -169,38 +190,3 @@ class AngleDetector(object):
             angle = detect_angle(result, ocr_anchor)
             # 旋转90度之后要重新计算角度
             return (angle - 1 + 4) % 4, result
-
-    def _detect_back(self, image):
-        mask = np.zeros(image.shape, dtype=np.uint8)
-        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
-        blur = cv2.GaussianBlur(gray, (3, 3), 0)
-        adaptive = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 15, 4)
-
-        cnts = cv2.findContours(adaptive, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-        cnts = cnts[0] if len(cnts) == 2 else cnts[1]
-
-        for c in cnts:
-            area = cv2.contourArea(c)
-            if area < 45000 and area > 20:
-                cv2.drawContours(mask, [c], -1, (255, 255, 255), -1)
-
-        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
-        h, w = mask.shape
-
-        # Horizontal
-        if w > h:
-            left = mask[0:h, 0:0 + w // 2]
-            right = mask[0:h, w // 2:]
-            left_pixels = cv2.countNonZero(left)
-            right_pixels = cv2.countNonZero(right)
-            print(f'left: {left_pixels}, right: {right_pixels}')
-            angle = 0 if left_pixels >= right_pixels else 2
-        # Vertical
-        else:
-            top = mask[0:h // 2, 0:w]
-            bottom = mask[h // 2:, 0:w]
-            top_pixels = cv2.countNonZero(top)
-            bottom_pixels = cv2.countNonZero(bottom)
-            print(f'top: {top_pixels}, bottom: {bottom_pixels}')
-            angle = 1 if bottom_pixels <= top_pixels else 3
-        return angle, None

+ 6 - 3
core/line_parser.py

@@ -1,6 +1,7 @@
 import numpy as np
 from dataclasses import dataclass
 
+
 # result 对象
 @dataclass
 class OcrResult(object):
@@ -34,7 +35,7 @@ class OcrResult(object):
     def center(self):
         l, t = self.lt
         r, b = self.rb
-        return [(r + l)/2, (b + t)/2]
+        return [(r + l) / 2, (b + t) / 2]
 
     def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
         if is_horizontal:
@@ -108,5 +109,7 @@ class LineParser(object):
 
                     res_row.add(res_j)
             res.append(res_row)
-        idx = self.is_horizontal + 0
-        return sorted([sorted(list(r), key=lambda x: x.lt[idx]) for r in res], key=lambda x: x[0].lt[idx])
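+        # For ID cards the image is already horizontal when it reaches line_parser,
+        # so the items inside each row are sorted by x (lt[0]) and the rows
+        # themselves are ordered by the y (lt[1]) of their first item.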
+        return sorted([sorted(list(r), key=lambda x: x.lt[0]) for r in res], key=lambda x: x[0].lt[1])

+ 23 - 5
core/ocr.py

@@ -7,31 +7,45 @@ from core.direction import *
 import numpy as np
 from paddleocr import PaddleOCR
 
-
+# <- input: pic, pic_type
+# 1. rotate pic (to upright)
+# 2. re-run recognition on pic (get res)
+# 3. line-process res (get res)
+# 4. apply field logic to res (get dict)
+# -> dict
+# ID card OCR
 @dataclass
 class IdCardOcr:
     ocr: PaddleOCR
+    # 角度探测器
     angle_detector: AngleDetector
 
     # 检测
-    def predict(self, image: np.ndarray, image_type):
+    # <- 传入pic pic_type
+    # -> dict
+    def predict(self, image: np.ndarray, image_type) -> ():
         image_type = int(image_type)
 
+        # 旋转后img angle result(生ocr)
         image, angle, result = self._pre_process(image, image_type)
-        print(f'---------- detect angle: {angle} --------')
+        print(f'---------- detect angle: {angle} 角度 --------')
         if image_type == 0:
             if angle != 0:
                 # 角度不为0需要重新识别,字面
                 _, _, result = self._ocr(image)
         else:
             _, _, result = self._ocr(image)
+        # ==> result(正向img-> 生ocr)
 
         return self._post_process(result, angle, image_type)
 
+    # 预处理(旋转图片)
+    # <- img(cv2) img_type
+    # -> 正向的img(旋转后) 源img角度 result(ocr生)
     def _pre_process(self, image, image_type) -> (np.ndarray, int, Any):
+        # pic角度 result(ocr生)
         angle, result = self.angle_detector.detect_angle(image, image_type)
 
-        # angle = detect_angle(image)
         if angle == 1:
             image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
         print(angle)  # 逆时针
@@ -42,8 +56,8 @@ class IdCardOcr:
 
         return image, angle, result
 
+    # 获取模型检测结果
     def _ocr(self, image):
-        # 获取模型检测结果
         result = self.ocr.ocr(image, cls=True)
         print("------------------")
         print(result)
@@ -58,6 +72,9 @@ class IdCardOcr:
         # print("......................................")
         return txts, confs, result
 
+    # <- result(正向img_生ocr) angle img_type
+    # == 对 正向img_res 进行[行处理]
+    # -> 最后要返回的结果 dict
     def _post_process(self, result, angle: int, image_type):
         line_parser = LineParser(result)
         line_result = line_parser.parse()
@@ -73,6 +90,7 @@ class IdCardOcr:
         else:
             raise Exception('无法识别')
 
+        # 字段逻辑处理后对res(dict)
         ocr_res = parser.parse()
 
         res = {

+ 2 - 9
core/parser.py

@@ -152,14 +152,7 @@ class FrontParser(Parser):
                     self.res["ethnicity"] = RecItem(txt.split("族")[-1], conf)
                     return
 
-        # for nation in self.result[1]:
-        #     txt = nation.txt
-        #     conf = nation.conf
-        #     res = re.findall(".*族[\u4e00-\u9fa5]+", txt)
-        #
-        #     if len(res) > 0:
-        #         self.res["ethnicity"] = RecItem(res[0].split("族")[-1], conf)
-        #         return
+
 
     def address(self):
         """
@@ -243,7 +236,7 @@ class FrontParser(Parser):
 
 
 class BackParser(Parser):
-    def __init__(self, ocr_results: List[OcrResult]):
+    def __init__(self, ocr_results: List[List[OcrResult]]):
         Parser.__init__(self, ocr_results)
 
     def expire_date(self):

+ 11 - 0
run.py

@@ -0,0 +1,11 @@
+if __name__ == '__main__':
+    import uvicorn
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', default='0.0.0.0')
+    parser.add_argument('--port', default=8080)
+    opt = parser.parse_args()
+
+    app_str = 'server:app'  # uvicorn import string "<module>:<attribute>": the `app` object in the `server` module, not this file (run.py)
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)