Browse Source

add direction

zeke-chin 2 years ago
parent
commit
eb81f3c24e
4 changed files with 222 additions and 64 deletions
  1. 179 36
      core/direction.py
  2. 2 12
      core/line_parser.py
  3. 38 13
      core/ocr.py
  4. 3 3
      core/parser.py

+ 179 - 36
core/direction.py

@@ -1,39 +1,182 @@
+import re
+
 import cv2
 import numpy as np
+from dataclasses import dataclass
+from enum import Enum
+from typing import Tuple, List
+
+import cv2
+import numpy as np
+from paddleocr import PaddleOCR
+
+from core.line_parser import LineParser
+
+
+# 枚举
+class Direction(Enum):
+    TOP = 0
+    RIGHT = 1
+    BOTTOM = 2
+    LEFT = 3
+
+
+# 父类
+class OcrAnchor(object):
+    # anchor的名字, 如身份证号、承办人等
+    def __init__(self, name: str, d: List[Direction]):
+        self.name = name
+        self.direction = d
+
+        # 定义枚举字典
+        def t_func(anchor, c, is_horizontal):
+            if is_horizontal:
+                return 0 if anchor[1] < c[1] else 2
+            else:
+                return 1 if anchor[0] > c[0] else 3
+
+        def l_func(anchor, c, is_horizontal):
+            if is_horizontal:
+                return 0 if anchor[0] < c[0] else 2
+            else:
+                return 1 if anchor[1] < c[1] else 3
+
+        def b_func(anchor, c, is_horizontal):
+            if is_horizontal:
+                return 0 if anchor[1] > c[1] else 2
+            else:
+                return 1 if anchor[0] < c[0] else 3
+
+        def r_func(anchor, c, is_horizontal):
+            if is_horizontal:
+                return 0 if anchor[0] > c[0] else 2
+            else:
+                return 1 if anchor[1] > c[1] else 3
+
+        self.direction_funcs = {
+            Direction.TOP: t_func,
+            Direction.LEFT: l_func,
+            Direction.BOTTOM: b_func,
+            Direction.RIGHT: r_func
+        }
+
+    # pic中心点
+    def get_pic_center(self, res) -> Tuple[float, float]:
+        boxs = []
+        for row in res:
+            for r in row:
+                boxs.extend(r.box)
+        boxs = np.stack(boxs)
+        l, t = np.min(boxs, 0)
+        r, b = np.max(boxs, 0)
+        return (l + r) / 2, (t + b) / 2
+
+    # 是否有锚点
+    def is_anchor(self, txt, box):
+        pass
+
+    # 找锚点
+    def find_anchor(self, res):
+        for row in res:
+            for r in row:
+                if self.is_anchor(r.txt, r.box):
+                    l, t = np.min(r.box, 0)
+                    r, b = np.max(r.box, 0)
+                    return True, (l + r) / 2, (t + b) / 2
+                    # return True, r.center[0], r.center[1]
+        return False, 0., 0.
+
+    # get angle
+    def locate_anchor(self, res, is_horizontal):
+        found, a_cx, a_cy = self.find_anchor(res)
+        cx, cy = self.get_pic_center(res)
+
+        if found is False: raise Exception(f'识别不到anchor{self.name}')
+
+        pre = None
+        for d in self.direction:
+            angle_func = self.direction_funcs.get(d, None)
+            angle = angle_func((a_cx, a_cy), (cx, cy), is_horizontal)
+            if pre is None:
+                pre = angle
+            else:
+                if pre != angle:
+                    raise Exception('angle is not compatible')
+        return pre
+
+
+# 子类1 户口本首页1
+class FrontAnchor(OcrAnchor):
+    def __init__(self, name: str, d: List[Direction]):
+        super(FrontAnchor, self).__init__(name, d)
+
+    def is_anchor(self, txt, box):
+        txts = re.findall('承办人', txt)
+        if len(txts) > 0:
+            return True
+        return False
+
+    def locate_anchor(self, res, is_horizontal):
+        return super(FrontAnchor, self).locate_anchor(res, is_horizontal)
+
+
+# 子类2 常驻人口页0
+class PeopleAnchor(OcrAnchor):
+    def __init__(self, name: str, d: List[Direction]):
+        super(PeopleAnchor, self).__init__(name, d)
+
+    def is_anchor(self, txt, box):
+        txts = re.findall('常住', txt)
+        if len(txts) > 0:
+            return True
+        return False
+
+    def locate_anchor(self, res, is_horizontal):
+        return super(PeopleAnchor, self).locate_anchor(res, is_horizontal)
+
+
+# 调用以上 🔧工具
+# <- ocr_生数据
+# == ocr_熟数据(行处理后)
+# -> 角度0/1/2/3
+def detect_angle(result, ocr_anchor: OcrAnchor):
+    lp = LineParser(result)
+    res = lp.parse()
+    print('------ angle ocr -------')
+    print(res)
+    print('------ angle ocr -------')
+    is_horizontal = lp.is_horizontal
+    return ocr_anchor.locate_anchor(res, is_horizontal)
+
+
+@dataclass
+class AngleDetector(object):
+    """
+    角度检测器
+    """
+    ocr: PaddleOCR
+
+    # 角度检测器
+    # <- img(cv2格式)  img_type
+    # == result <- img(cv2)
+    # -> angle       result(ocr生)
+    def detect_angle(self, img, image_type):
+        image_type = int(image_type)
+        ocr_anchor = PeopleAnchor('常住', [Direction.TOP]) if image_type == 0 else FrontAnchor('承办人', [Direction.BOTTOM,
+                                                                                                       Direction.LEFT])
+
+        result = self.ocr.ocr(img, cls=True)
+
+        try:
+            angle = detect_angle(result, ocr_anchor)
+            return angle, result
+
 
-def detect_angle(image):
-    mask = np.zeros(image.shape, dtype=np.uint8)
-    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
-    blur = cv2.GaussianBlur(gray, (3,3), 0)
-    adaptive = cv2.adaptiveThreshold(blur,255,cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV,15,4)
-
-    cnts = cv2.findContours(adaptive, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
-
-    for c in cnts:
-        area = cv2.contourArea(c)
-        if area < 45000 and area > 20:
-            cv2.drawContours(mask, [c], -1, (255,255,255), -1)
-
-    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
-    h, w = mask.shape
-    
-    # Horizontal
-    if w > h:
-        left = mask[0:h, 0:0+w//2]
-        right = mask[0:h, w//2:]
-        left_pixels = cv2.countNonZero(left)
-        right_pixels = cv2.countNonZero(right)
-        return 0 if left_pixels >= right_pixels else 180
-    # Vertical
-    else:
-        top = mask[0:h//2, 0:w]
-        bottom = mask[h//2:, 0:w]
-        top_pixels = cv2.countNonZero(top)
-        bottom_pixels = cv2.countNonZero(bottom)
-        return 90 if bottom_pixels >= top_pixels else 270
-
-if __name__ == '__main__':
-    image = cv2.imread('d40.jpg')
-    angle = detect_angle(image)
-    print(angle)
+        except Exception as e:
+            print(e)
+            # 如果第一次识别不到,旋转90度再识别
+            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+            result = self.ocr.ocr(img, cls=True)
+            angle = detect_angle(result, ocr_anchor)
+            # 旋转90度之后要重新计算角度
+            return (angle - 1 + 4) % 4, result

+ 2 - 12
core/line_parser.py

@@ -69,16 +69,6 @@ class LineParser(object):
     def confidence(self):
         return np.mean([r.conf for r in self.ocr_res])
 
-    def sorted(self, raw_res):
-        out_idx = self.is_horizontal + 0
-        in_idx = not self.is_horizontal + 0
-        sorted_res = []
-        res_out = sorted([list(r) for r in raw_res], key=lambda x: x[0].lt[out_idx])
-        for res_in in res_out:
-            sorted_res.append(sorted([rr2 for rr2 in res_in], key=lambda r: r.lt[in_idx]))
-
-        return sorted_res
-
     # 处理器函数
     def parse(self, eps=40.0):
         # 存返回值
@@ -117,5 +107,5 @@ class LineParser(object):
 
                     res_row.add(res_j)
             res.append(res_row)
-
-        return self.sorted(res)
+        idx = self.is_horizontal + 0
+        return sorted([list(r) for r in res], key=lambda x: x[0].lt[idx])

+ 38 - 13
core/ocr.py

@@ -1,5 +1,7 @@
 from dataclasses import dataclass
 
+import cv2
+
 from core.line_parser import LineParser
 from core.parser import *
 from core.direction import *
@@ -7,27 +9,45 @@ import numpy as np
 from paddleocr import PaddleOCR
 
 
+# <- 传入pic pic_type
+# 1. 旋转pic  (to 正向)
+# 2. 重写识别pic  (get res)
+# 3. 行处理res  (get res)
+# 4. 对res字段逻辑识别  (get dict)
+# -> dict
+
 @dataclass
 class IdCardOcr:
     ocr: PaddleOCR
+    # 角度探测器
+    angle_detector: AngleDetector
 
+    # master
     def predict(self, image: np.ndarray, image_type: str):
-        image, angle = self._pre_process(image)
-        txts, confs, result = self._ocr(image)
+        img_type = int(image_type)
+
+        img, angle, result = self._rotate_img(image, img_type)
+        print(f'---------- detect angle: {angle} 图片角度 ----------')
+        if image_type == 0:
+            if angle != 0:
+                # 角度不为0需要重新识别,字面
+                _, _, result = self._ocr(image)
+        else:
+            _, _, result = self._ocr(image)
 
-        # parser = PeopleRegBookParser(txts, confs)
         return self._post_process(result, angle, image_type)
 
-    def _pre_process(self, image) -> (np.ndarray, int):
-        angle = detect_angle(image)
-        print(angle)  # 逆时针
-        # if angle == 180:
-        #     image = cv2.rotate(image, cv2.ROTATE_180)
-        if angle == 90:
+    # 检测角度
+    def _rotate_img(self, image, image_type) -> (np.ndarray, int):
+        angle, result = self.angle_detector.detect_angle(image, image_type)
+        if angle == 1:
             image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
-        if angle == 270:
-            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
-        return image, angle
+        if angle == 2:
+            image = cv2.rotate(image, cv2.ROTATE_180)
+        if angle == 3:
+            image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+        print(angle)
+        return image, angle, result
 
     def _ocr(self, image):
         # 获取模型检测结果
@@ -46,6 +66,8 @@ class IdCardOcr:
         return txts, confs, result
 
     def _post_process(self, result, angle: int, image_type: str):
+        # 行处理
+
         line_parser = LineParser(result)
         line_result = line_parser.parse()
         conf = line_parser.confidence
@@ -62,10 +84,13 @@ class IdCardOcr:
         else:
             raise Exception('未传入 image_type')
 
+        # 字段逻辑处理后对res(dict)
         ocr_res = parser.parse()
+
         res = {
             "confidence": conf,
-            "orientation": (4 - angle // 90) % 4,  # 原angle是逆时针,转成顺时针
+            "img_type": str(image_type),
+            "orientation": angle,  # 原angle是逆时针,转成顺时针
             **ocr_res
         }
 

+ 3 - 3
core/parser.py

@@ -49,7 +49,7 @@ class Parser(object):
         return self.res
 
 
-# 1 常驻人口
+# 1 户口本首
 class FrontRegBookParser(Parser):
     def type_(self):
         """
@@ -116,7 +116,7 @@ class FrontRegBookParser(Parser):
             self.res["address_region"] = RecItem(region, conf)
             self.res["address_detail"] = RecItem(detail, conf)
 
-        self.res['address'].text = df
+        self.res['address'].text = province + city + region + detail
 
     # 存入
     def parse(self):
@@ -125,7 +125,7 @@ class FrontRegBookParser(Parser):
         return {key: self.res[key].to_dict() for key in self.keys}
 
 
-# 0 户口本首
+# 0 常驻人口
 class PeopleRegBookParser(Parser):
 
     def full_name(self):