Pārlūkot izejas kodu

refactor anchor

Zhang Li 2 gadi atpakaļ
vecāks
revīzija
10563f1bf5
2 mainīti faili ar 95 papildinājumiem un 104 dzēšanām
  1. 88 97
      core/direction.py
  2. 7 7
      testing/all_test.py

+ 88 - 97
core/direction.py

@@ -1,5 +1,6 @@
 import re
 from dataclasses import dataclass
+from typing import Tuple
 
 import cv2
 import numpy as np
@@ -8,51 +9,87 @@ from paddleocr import PaddleOCR
 from core.line_parser import LineParser
 
 
-def get_rec_area(res):
-    """获得整张身份证的识别区域, 返回识别区域的中心点"""
-    boxes = []
-    for row in res:
-        for r in row:
-            boxes.extend(r.box)
-    boxes = np.stack(boxes)
-    l, t = np.min(boxes, 0)
-    r, b = np.max(boxes, 0)
-    # 识别区域的box
-    big_box = [[l, t], [r, t], [r, b], [l, b]]
-    w, h = (r - l, b - t)
-    return (l + r) / 2, (t + b) / 2, big_box
-
-
-def find_idno(res):
-    """寻找身份证号的识别区域以及中心点,根据身份证的w > h判断是否水平"""
-    for row in res:
-        for r in row:
-            txt = r.txt.replace('-', '').replace(' ', '')
-            box = r.box
-            txts = re.findall('\d{10,18}', txt)
-            if len(txts) > 0:
-                l, t = np.min(box, 0)
-                r, b = np.max(box, 0)
-                return txts[0], (l + r) / 2, (t + b) / 2, (r - l) > (b - t), box
-    return '', 0, 0, True, []
-
-
-def detect_angle(result):
+class OcrAnchor(object):
+    """Base class that locates an "anchor" text item in parsed OCR rows.
+
+    Subclasses decide what counts as the anchor by overriding ``is_anchor``;
+    ``locate_anchor`` then turns the anchor's position relative to the whole
+    recognised area into an orientation code (0, 1, 2 or 3).
+    """
+
+    # Takes the name of the anchor to recognise, e.g. the ID number
+    def __init__(self, name: str):
+        self.name = name
+
+    def get_rec_area(self, res) -> Tuple[float, float]:
+        """Return the centre point of the area covered by every recognised box.
+
+        ``res`` is iterated as rows of items, each exposing a ``box`` attribute.
+        """
+        boxes = []
+        for row in res:
+            for r in row:
+                boxes.extend(r.box)
+        # assumes every box entry is a two-value point (x, y) - TODO confirm
+        boxes = np.stack(boxes)
+        l, t = np.min(boxes, 0)
+        r, b = np.max(boxes, 0)
+        # box of the recognised area
+        # big_box = [[l, t], [r, t], [r, b], [l, b]]
+        # w, h = (r - l, b - t)
+        return (l + r) / 2, (t + b) / 2
+
+    def is_anchor(self, txt, box) -> bool:
+        """Tell whether the item with text ``txt`` and ``box`` is the anchor.
+
+        NOTE(review): this base implementation returns None (falsy), so
+        ``find_anchor`` never matches unless a subclass overrides this method.
+        """
+        pass
+
+    def find_anchor(self, res) -> Tuple[bool, float, float]:
+        """Find the first item accepted by ``is_anchor`` and return its centre.
+
+        Returns ``(True, cx, cy)`` for the first match, where the centre is
+        taken from the min/max of the item's box, or ``(False, 0., 0.)`` when
+        no item matches. '-' and ' ' are stripped from the text before testing.
+        """
+        for row in res:
+            for r in row:
+                txt = r.txt.replace('-', '').replace(' ', '')
+                box = r.box
+                if self.is_anchor(txt, box):
+                    l, t = np.min(box, 0)
+                    # rebinding ``r`` is harmless here: the method returns right away
+                    r, b = np.max(box, 0)
+                    return True, (l + r) / 2, (t + b) / 2
+        return False, 0., 0.
+
+    def locate_anchor(self, res, is_horizontal) -> int:
+        """Return the orientation code derived from where the anchor sits.
+
+        Raises ``Exception`` when no anchor is found in ``res``.
+        """
+        found, id_cx, id_cy = self.find_anchor(res)
+
+        # If the anchor cannot be recognised
+        if not found: raise Exception(f'识别不到anchor{self.name}')
+        cx, cy = self.get_rec_area(res)
+        # print(f'id_cx: {id_cx}, id_cy: {id_cy}')
+        # print(f'cx: {cx}, cy: {cy}')
+        if is_horizontal:
+            # Horizontal: anchor below the centre of the recognised area means 0 degrees, otherwise 180 degrees
+            return 0 if id_cy > cy else 2
+        else:
+            # Vertical: anchor to the left of the centre means 90 degrees, otherwise 270 degrees
+            return 1 if id_cx < cx else 3
+
+
+class FrontSideAnchor(OcrAnchor):
+    def __init__(self, name: str):
+        super(FrontSideAnchor, self).__init__(name)
+
+    def is_anchor(self, txt, box) -> bool:
+        txts = re.findall('\d{10,18}', txt)
+        if len(txts) > 0:
+            return True
+        return False
+
+
+class BackSideAnchor(OcrAnchor):
+    def __init__(self, name: str):
+        super(BackSideAnchor, self).__init__(name)
+
+    def is_anchor(self, txt, box) -> bool:
+        txt = txt.replace('.', '')
+        txts = re.findall('有效期', txt)
+        if len(txts) > 0:
+            return True
+        return False
+
+
+def detect_angle(result, ocr_anchor: OcrAnchor):
+    """Return the orientation code for an OCR result, using ``ocr_anchor``.
+
+    ``result`` is parsed into rows with ``LineParser``; the parsed rows are
+    printed for debugging, then passed with ``lp.is_horizontal`` to
+    ``ocr_anchor.locate_anchor``, whose return value is returned unchanged.
+    Any exception raised by ``locate_anchor`` propagates to the caller.
+    """
     lp = LineParser(result)
     res = lp.parse()
-    idno, id_cx, id_cy, is_horizon, id_box = find_idno(res)
-    # 如果识别不到身份证号
-    if not idno: raise Exception('识别不到身份证号')
-    cx, cy, big_box = get_rec_area(res)
-    # print(f'id_cx: {id_cx}, id_cy: {id_cy}')
-    # print(f'cx: {cx}, cy: {cy}')
-
-    if is_horizon:
-        # 如果是水平的,身份证号的位置在相对识别区域的下方,方向则为0度,否则是180度
-        return 0 if id_cy > cy else 2
-    else:
-        # 如果是竖直的,身份证号的相对位置如果在左边,方向为90度,否则270度
-        return 1 if id_cx < cx else 3
+    print('------ angle ocr -------')
+    print(res)
+    print('------ angle ocr -------')
+    is_horizontal = lp.is_horizontal
+    return ocr_anchor.locate_anchor(res, is_horizontal)
 
 
 @dataclass
@@ -62,26 +99,21 @@ class AngleDetector(object):
 
     def detect_angle(self, img, image_type):
         image_type = int(image_type)
-        if image_type != 0:
-            return self._detect_back(img)
-
-        return self._detect_front(img)
-
-    def _detect_front(self, img):
+        ocr_anchor = BackSideAnchor('有效期') if image_type != 0 else FrontSideAnchor('身份证号')
         result = self.ocr.ocr(img, cls=True)
 
-        print('------ angle ocr -------')
-        print(result)
-        print('------ angle ocr -------')
-
         try:
-            return detect_angle(result), result
+            angle = detect_angle(result, ocr_anchor)
+            return angle, result
+
         except Exception as e:
             print(e)
+            # 如果第一次识别不到,旋转90度再识别
             img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
             result = self.ocr.ocr(img, cls=True)
-            angle = detect_angle(result)
-            return (angle-1+4)/4, result
+            angle = detect_angle(result, ocr_anchor)
+            # 旋转90度之后要重新计算角度
+            return (angle - 1 + 4) / 4, result
 
     def _detect_back(self, image):
         mask = np.zeros(image.shape, dtype=np.uint8)
@@ -117,44 +149,3 @@ class AngleDetector(object):
             print(f'top: {top_pixels}, bottom: {bottom_pixels}')
             angle = 1 if bottom_pixels <= top_pixels else 3
         return angle, None
-
-
-#
-#
-# def detect_angle(image):
-#     mask = np.zeros(image.shape, dtype=np.uint8)
-#     gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
-#     blur = cv2.GaussianBlur(gray, (3, 3), 0)
-#     adaptive = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 15, 4)
-#
-#     cnts = cv2.findContours(adaptive, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-#     cnts = cnts[0] if len(cnts) == 2 else cnts[1]
-#
-#     for c in cnts:
-#         area = cv2.contourArea(c)
-#         if area < 45000 and area > 20:
-#             cv2.drawContours(mask, [c], -1, (255, 255, 255), -1)
-#
-#     mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
-#     h, w = mask.shape
-#
-#     # Horizontal
-#     if w > h:
-#         left = mask[0:h, 0:0 + w // 2]
-#         right = mask[0:h, w // 2:]
-#         left_pixels = cv2.countNonZero(left)
-#         right_pixels = cv2.countNonZero(right)
-#         return 0 if left_pixels >= right_pixels else 180
-#     # Vertical
-#     else:
-#         top = mask[0:h // 2, 0:w]
-#         bottom = mask[h // 2:, 0:w]
-#         top_pixels = cv2.countNonZero(top)
-#         bottom_pixels = cv2.countNonZero(bottom)
-#         return 90 if bottom_pixels >= top_pixels else 270
-#
-#
-# if __name__ == '__main__':
-#     image = cv2.imread('d40.jpg')
-#     angle = detect_angle(image)
-#     print(angle)

+ 7 - 7
testing/all_test.py

@@ -172,13 +172,13 @@ import pytest
 @pytest.mark.parametrize(
     "image_path, status, orientation, expire_date",
     [
-        ('../images/all/1/01_0.jpg', '000', 0, '20220511-20410511'),
-        ('../images/all/1/02_0.jpg', '000', 3, '20180531-20280531'),
-        ('../images/all/1/03_180.jpg', '000', 2, '20170109-20270109'),
-        ('../images/all/1/04_90.jpg', '000', 1, '20190715-20390715'),
-        ('../images/all/1/05-270.jpg', '000', 3, '20140320-20240320'),
-        ('../images/all/1/small.png', '000', 0, '20190620-20290620'),
-        ('../images/all/1/special.jpg', '000', 0, '20190813-20290813'),
+        (Path('../images/all/1/01_0.jpg'), '000', 0, '20220511-20420511'),
+        (Path('../images/all/1/02_0.jpg'), '000', 3, '20180531-20280531'),
+        (Path('../images/all/1/03_180.jpg'), '000', 2, '20170109-20270109'),
+        (Path('../images/all/1/04_90.jpg'), '000', 1, '20190715-20390715'),
+        (Path('../images/all/1/05-270.jpg'), '000', 3, '20140320-20240320'),
+        (Path('../images/all/1/small.png'), '000', 0, '20190620-20290620'),
+        (Path('../images/all/1/special.jpg'), '000', 0, '20190813-20290813'),
     ]
 )
 def test_back_side(image_path, status, orientation, expire_date):