Bläddra i källkod

update angle detector and model

Zhang Li 2 år sedan
förälder
incheckning
5986d6402b

+ 2 - 1
Dockerfile

@@ -90,7 +90,8 @@ autorestart=true\n\
 startretries=0\n\
 redirect_stderr=true\n\
 stdout_logfile=/var/log/be.log\n\
-stdout_logfile_maxbytes=0\n\
+stdout_logfile_maxbytes=50MB\n\
+environment=PYTHONUNBUFFERED=1\n\
 " > /etc/supervisor/conf.d/be.conf
 
 ARG VERSION

BIN
ch_ppocr_server_v2.0_det_infer/inference.pdiparams


BIN
ch_ppocr_server_v2.0_det_infer/inference.pdiparams.info


BIN
ch_ppocr_server_v2.0_det_infer/inference.pdmodel


BIN
ch_ppocr_server_v2.0_rec_infer/inference.pdiparams


BIN
ch_ppocr_server_v2.0_rec_infer/inference.pdiparams.info


BIN
ch_ppocr_server_v2.0_rec_infer/inference.pdmodel


+ 145 - 75
core/direction.py

@@ -1,90 +1,160 @@
-import dataclasses
+import re
+from dataclasses import dataclass
 
 import cv2
 import numpy as np
-
-from dataclasses import dataclass
 from paddleocr import PaddleOCR
 
+from core.line_parser import LineParser
+
+
+def get_rec_area(res):
+    """获得整张身份证的识别区域, 返回识别区域的中心点"""
+    boxes = []
+    for row in res:
+        for r in row:
+            boxes.extend(r.box)
+    boxes = np.stack(boxes)
+    l, t = np.min(boxes, 0)
+    r, b = np.max(boxes, 0)
+    # 识别区域的box
+    big_box = [[l, t], [r, t], [r, b], [l, b]]
+    w, h = (r - l, b - t)
+    return (l + r) / 2, (t + b) / 2, big_box
+
+
+def find_idno(res):
+    """寻找身份证号的识别区域以及中心点,根据身份证的w > h判断是否水平"""
+    for row in res:
+        for r in row:
+            txt = r.txt.replace('-', '').replace(' ', '')
+            box = r.box
+            txts = re.findall('\d{10,18}', txt)
+            if len(txts) > 0:
+                l, t = np.min(box, 0)
+                r, b = np.max(box, 0)
+                return txts[0], (l + r) / 2, (t + b) / 2, (r - l) > (b - t), box
+    return '', 0, 0, True, []
+
+
+def detect_angle(result):
+    lp = LineParser(result)
+    res = lp.parse()
+    idno, id_cx, id_cy, is_horizon, id_box = find_idno(res)
+    # 如果识别不到身份证号
+    if not idno: raise Exception('识别不到身份证号')
+    cx, cy, big_box = get_rec_area(res)
+    # print(f'id_cx: {id_cx}, id_cy: {id_cy}')
+    # print(f'cx: {cx}, cy: {cy}')
+
+    if is_horizon:
+        # 如果是水平的,身份证号的位置在相对识别区域的下方,方向则为0度,否则是180度
+        return 0 if id_cy > cy else 2
+    else:
+        # 如果是竖直的,身份证号的相对位置如果在左边,方向为90度,否则270度
+        return 1 if id_cx < cx else 3
+
 
 @dataclass
 # 角度检测器
 class AngleDetector(object):
     ocr: PaddleOCR
 
-    def detect_angle(self, img, result) -> int:
-        wc = 0
-        hc = 0
-        count_0 = 0
-        count_180 = 0
-        angle = 0
+    def detect_angle(self, img, image_type):
+        image_type = int(image_type)
+        if image_type != 0:
+            return self._detect_back(img)
+
+        return self._detect_front(img)
+
+    def _detect_front(self, img):
+        result = self.ocr.ocr(img, cls=True)
+
+        print('------ angle ocr -------')
         print(result)
-        for res in result:
-            txt = res[1][0]
-            if '号' not in txt: continue
-            a = np.array(res[0])
-            l, t = np.min(a, axis=0).tolist()
-            r, b = np.max(a, axis=0).tolist()
-            l, t, r, b = list(map(int, [l, t, r, b]))
-            if b - t > r - l:
-                hc += 1
-            else:
-                wc += 1
-            imgb = img[t:b, l:r, :]
-            r = self.ocr.ocr(imgb, det=False, rec=False, cls=True)
-            print(f'ocr angle: {r}')
-            if int(r[0][0]) == 180:
-                count_180 += 1
-            else:
-                count_0 += 1
-        if hc >= wc:
-            if count_0 >= count_180:
-                angle = 90
-            else:
-                angle = 270
+        print('------ angle ocr -------')
+
+        try:
+            return detect_angle(result), result
+        except Exception as e:
+            print(e)
+            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+            result = self.ocr.ocr(img, cls=True)
+            angle = detect_angle(result)
+            return (angle-1+4)/4, result
+
+    def _detect_back(self, image):
+        mask = np.zeros(image.shape, dtype=np.uint8)
+        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+        blur = cv2.GaussianBlur(gray, (3, 3), 0)
+        adaptive = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 15, 4)
+
+        cnts = cv2.findContours(adaptive, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+        cnts = cnts[0] if len(cnts) == 2 else cnts[1]
+
+        for c in cnts:
+            area = cv2.contourArea(c)
+            if area < 45000 and area > 20:
+                cv2.drawContours(mask, [c], -1, (255, 255, 255), -1)
+
+        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+        h, w = mask.shape
+
+        # Horizontal
+        if w > h:
+            left = mask[0:h, 0:0 + w // 2]
+            right = mask[0:h, w // 2:]
+            left_pixels = cv2.countNonZero(left)
+            right_pixels = cv2.countNonZero(right)
+            print(f'left: {left_pixels}, right: {right_pixels}')
+            angle = 0 if left_pixels >= right_pixels else 2
+        # Vertical
         else:
-            if count_0 > count_180:
-                angle = 0
-            else:
-                angle = 180
-
-        return angle
-
-
-def detect_angle(image):
-    mask = np.zeros(image.shape, dtype=np.uint8)
-    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
-    blur = cv2.GaussianBlur(gray, (3, 3), 0)
-    adaptive = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 15, 4)
-
-    cnts = cv2.findContours(adaptive, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
-
-    for c in cnts:
-        area = cv2.contourArea(c)
-        if area < 45000 and area > 20:
-            cv2.drawContours(mask, [c], -1, (255, 255, 255), -1)
-
-    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
-    h, w = mask.shape
-
-    # Horizontal
-    if w > h:
-        left = mask[0:h, 0:0 + w // 2]
-        right = mask[0:h, w // 2:]
-        left_pixels = cv2.countNonZero(left)
-        right_pixels = cv2.countNonZero(right)
-        return 0 if left_pixels >= right_pixels else 180
-    # Vertical
-    else:
-        top = mask[0:h // 2, 0:w]
-        bottom = mask[h // 2:, 0:w]
-        top_pixels = cv2.countNonZero(top)
-        bottom_pixels = cv2.countNonZero(bottom)
-        return 90 if bottom_pixels >= top_pixels else 270
+            top = mask[0:h // 2, 0:w]
+            bottom = mask[h // 2:, 0:w]
+            top_pixels = cv2.countNonZero(top)
+            bottom_pixels = cv2.countNonZero(bottom)
+            print(f'top: {top_pixels}, bottom: {bottom_pixels}')
+            angle = 1 if bottom_pixels <= top_pixels else 3
+        return angle, None
 
 
-if __name__ == '__main__':
-    image = cv2.imread('d40.jpg')
-    angle = detect_angle(image)
-    print(angle)
+#
+#
+# def detect_angle(image):
+#     mask = np.zeros(image.shape, dtype=np.uint8)
+#     gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+#     blur = cv2.GaussianBlur(gray, (3, 3), 0)
+#     adaptive = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 15, 4)
+#
+#     cnts = cv2.findContours(adaptive, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+#     cnts = cnts[0] if len(cnts) == 2 else cnts[1]
+#
+#     for c in cnts:
+#         area = cv2.contourArea(c)
+#         if area < 45000 and area > 20:
+#             cv2.drawContours(mask, [c], -1, (255, 255, 255), -1)
+#
+#     mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+#     h, w = mask.shape
+#
+#     # Horizontal
+#     if w > h:
+#         left = mask[0:h, 0:0 + w // 2]
+#         right = mask[0:h, w // 2:]
+#         left_pixels = cv2.countNonZero(left)
+#         right_pixels = cv2.countNonZero(right)
+#         return 0 if left_pixels >= right_pixels else 180
+#     # Vertical
+#     else:
+#         top = mask[0:h // 2, 0:w]
+#         bottom = mask[h // 2:, 0:w]
+#         top_pixels = cv2.countNonZero(top)
+#         bottom_pixels = cv2.countNonZero(bottom)
+#         return 90 if bottom_pixels >= top_pixels else 270
+#
+#
+# if __name__ == '__main__':
+#     image = cv2.imread('d40.jpg')
+#     angle = detect_angle(image)
+#     print(angle)

+ 6 - 4
core/line_parser.py

@@ -38,9 +38,11 @@ class OcrResult(object):
 
     def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
         if is_horizontal:
-            return abs(self.lt[1] - b.lt[1]) < eps
+            eps = 0.5 * (self.wh[1] + b.wh[1])
+            return abs(self.center[1] - b.center[1]) < eps
         else:
-            return abs(self.rb[0] - b.rb[0]) < eps
+            eps = 0.5 * (self.wh[0] + b.wh[0])
+            return abs(self.center[0] - b.center[0]) < eps
 
 
 # 行处理器
@@ -50,7 +52,7 @@ class LineParser(object):
         for re in ocr_raw_result:
             o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
             self.ocr_res.append(o)
-        self.eps = self.avg_height * 0.66
+        self.eps = self.avg_height * 0.7
 
     @property
     def is_horizontal(self):
@@ -107,4 +109,4 @@ class LineParser(object):
                     res_row.add(res_j)
             res.append(res_row)
         idx = self.is_horizontal + 0
-        return sorted([list(r) for r in res], key=lambda x: x[0].lt[idx])
+        return sorted([sorted(list(r), key=lambda x: x.lt[idx]) for r in res], key=lambda x: x[0].lt[idx])

+ 28 - 16
core/ocr.py

@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import Any
 
 from core.line_parser import LineParser
 from core.parser import *
@@ -13,25 +14,33 @@ class IdCardOcr:
     angle_detector: AngleDetector
 
     # 检测
-    def predict(self, image: np.ndarray, image_type: str = '0'):
-        # image, angle = self._pre_process(image)
-        # 识别出的 => 字段、置信度、(字段,置信度)
-        txts, confs, result = self._ocr(image)
-        # 角度
-        angle = self.angle_detector.detect_angle(image, result)
+    def predict(self, image: np.ndarray, image_type):
+        image_type = int(image_type)
+
+        image, angle, result = self._pre_process(image, image_type)
+        print(f'---------- detect angle: {angle} --------')
+        if image_type == 0:
+            if angle != 0:
+                # 角度不为0需要重新识别,字面
+                _, _, result = self._ocr(image)
+        else:
+            _, _, result = self._ocr(image)
 
         return self._post_process(result, angle, image_type)
 
-    def _pre_process(self, image) -> (np.ndarray, int):
-        angle = detect_angle(image)
+    def _pre_process(self, image, image_type) -> (np.ndarray, int, Any):
+        angle, result = self.angle_detector.detect_angle(image, image_type)
+
+        # angle = detect_angle(image)
+        if angle == 1:
+            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
         print(angle)  # 逆时针
-        if angle == 180:
+        if angle == 2:
             image = cv2.rotate(image, cv2.ROTATE_180)
-        if angle == 90:
+        if angle == 3:
             image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
-        if angle == 270:
-            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
-        return image, angle
+
+        return image, angle, result
 
     def _ocr(self, image):
         # 获取模型检测结果
@@ -49,9 +58,12 @@ class IdCardOcr:
         # print("......................................")
         return txts, confs, result
 
-    def _post_process(self, result, angle: int, image_type: str):
+    def _post_process(self, result, angle: int, image_type):
         line_parser = LineParser(result)
         line_result = line_parser.parse()
+        print('-------------')
+        print(line_result)
+        print('-------------')
         conf = line_parser.confidence
 
         if int(image_type) == 0:
@@ -65,8 +77,8 @@ class IdCardOcr:
 
         res = {
             "confidence": conf,
-            "card_type": image_type,
-            "orientation": angle // 90,  # 原angle是逆时针,转成顺时针
+            "card_type": str(image_type),
+            "orientation": angle,  # 原angle是逆时针,转成顺时针
             **ocr_res
         }
         print(res)

+ 8 - 5
core/parser.py

@@ -6,7 +6,7 @@ from collections import defaultdict
 import numpy as np
 import cpca
 from typing import List
-
+from zhon.hanzi import punctuation
 from core.line_parser import OcrResult
 
 
@@ -16,7 +16,7 @@ class RecItem:
     confidence: float = 0.
 
     def to_dict(self):
-        return {"text": self.text, "confidence": np.nan_to_num(self.confidence)}
+        return {"text": self.text.strip(), "confidence": np.nan_to_num(self.confidence)}
 
 
 class Parser(object):
@@ -172,7 +172,9 @@ class FrontParser(Parser):
             for r in row:
                 txt = r.txt
                 if '性别' in txt or '出生' in txt or '民族' in txt: continue
-
+                punctuation_str = punctuation
+                for i in punctuation:
+                    txt = txt.replace(i, '')
                 if (
                         "住址" in txt
                         or "址" in txt
@@ -189,6 +191,7 @@ class FrontParser(Parser):
                         or "旗" in txt
                         or "号" in txt
                         or "户" in txt
+                        or "室" in txt
                 ):
                     # if "住址" in txt or "省" in txt or "址" in txt:
                     if "住址" in txt or "址" in txt:
@@ -248,9 +251,9 @@ class BackParser(Parser):
             for r in row:
                 txt = r.txt
                 txt = txt.replace('.', '')
-                res = re.findall('\d{8}\-\d{8}', txt)
+                res = re.findall('\d{8}\-\d{4}', txt)
                 if res:
-                    self.res["expire_date"] = RecItem(res[0], r.conf)
+                    self.res["expire_date"] = RecItem(res[0]+res[0][4:8], r.conf)
                     return
                 res = re.findall('\d{8}\-长期', txt)
                 if res:

+ 1 - 0
environment.yml

@@ -15,6 +15,7 @@ dependencies:
       - paddlehub
       - fastapi
       - uvicorn
+      - zhon
       - jinja2
       - aiofiles
       - python-multipart

+ 25 - 1
server.py

@@ -50,8 +50,21 @@ print(f'use gpu: {use_gpu}')
 #                 # 网络不够大、不够深
 #                 # 数据集普遍较小,batch size普遍较小
 #                 warmup=True)
+# ocr = PaddleOCR(use_angle_cls=True,
+#                 use_gpu=use_gpu)
+
 ocr = PaddleOCR(use_angle_cls=True,
-                use_gpu=use_gpu)
+                rec_model_dir='./ch_ppocr_server_v2.0_rec_infer',
+                det_model_dir='./ch_ppocr_server_v2.0_det_infer',
+                cls_model_dir='./idcard_cls_infer',
+                ocr_version='PP-OCRv2',
+                rec_algorithm='CRNN',
+                use_gpu=use_gpu,
+                det_db_unclip_ratio=2.5,
+                det_db_thresh=0.1,
+                det_db_box_thresh=0.3,
+                warmup=True)
+
 
 
 # 初始化 角度检测器 对象
@@ -90,6 +103,17 @@ def idcard(request: Request, id_card: IdCardInfo):
     return m.predict(image, id_card.image_type)
 
 
+@app.post("/ocr_system/orientation")
+@sxtimeit
+@web_try()
+# Input  => request body with a base64-encoded image, decoded via base64_to_np
+# Output => {'orientation': angle}; only the orientation code is returned,
+#           the second element of the detector's result is discarded
+def detect_angle(request: Request, id_card: IdCardInfo):
+    image = base64_to_np(id_card.image)
+    angle, _ = ad.detect_angle(image, id_card.image_type)
+    return {'orientation': angle}
+
+
 if __name__ == '__main__':
     import uvicorn
     import argparse

+ 1 - 1
testing/address_test.py

@@ -6,7 +6,7 @@ import requests
 
 
 
-url = 'http://192.168.199.208:18081'
+url = 'http://192.168.199.249:18081'
 
 
 def send_request(image_path, image_type):

+ 1 - 1
testing/id_test.py

@@ -6,7 +6,7 @@ import requests
 
 
 
-url = 'http://192.168.199.208:18081'
+url = 'http://192.168.199.249:18081'
 
 
 def send_request(image_path, image_type):

+ 1 - 1
testing/name_test.py

@@ -6,7 +6,7 @@ import requests
 
 
 
-url = 'http://192.168.199.208:18081'
+url = 'http://192.168.199.249:18081'
 
 
 def send_request(image_path, image_type):

+ 1 - 1
testing/nation_test.py

@@ -6,7 +6,7 @@ import requests
 
 
 
-url = 'http://192.168.199.208:18081'
+url = 'http://192.168.199.249:18081'
 
 
 def send_request(image_path, image_type):

+ 11 - 12
testing/orient_test.py

@@ -6,7 +6,7 @@ import requests
 
 
 
-url = 'http://192.168.199.208:18081'
+url = 'http://192.168.199.249:18081'
 
 
 def send_request(image_path, image_type):
@@ -18,21 +18,16 @@ def send_request(image_path, image_type):
 
 
 class TestIdCardAddress(unittest.TestCase):
-    def _helper(self, image_path, orient):
+    def _helper(self, image_path, orient, image_type='0'):
         root = Path(__file__).parent
         image_path = str(root / image_path)
-        r = send_request(image_path, '0')
+        r = send_request(image_path, image_type)
         self.assertEqual(orient, r['result']['orientation'], f'{image_path} orientation case error')
 
-    def _helper1(self, image_path, orient):
-        root = Path(__file__).parent
-        image_path = str(root / image_path)
-        r = send_request(image_path, '1')
-        self.assertEqual(orient, r['result']['orientation'], f'{image_path} orientation case error')
 
     def test_01_270(self):
         image_path = '../images/false/miss_orient/01_270.jpg'
-        self._helper(image_path, 0)
+        self._helper(image_path, 3)
 
     def test_02_270(self):
         image_path = '../images/false/miss_orient/02_270.jpg'
@@ -40,11 +35,11 @@ class TestIdCardAddress(unittest.TestCase):
 
     def test_04(self):
         image_path = '../images/false/miss_orient/04.jpg'
-        self._helper(image_path, 0)
+        self._helper(image_path, 0, '1')
 
     def test_05(self):
         image_path = '../images/false/miss_orient/05.jpg'
-        self._helper(image_path, 0)
+        self._helper(image_path, 3)
 
     def test_06(self):
         image_path = '../images/false/miss_orient/06.png'
@@ -52,7 +47,11 @@ class TestIdCardAddress(unittest.TestCase):
 
     def test_07(self):
         image_path = '../images/false/miss_orient/07.jpg'
-        self._helper(image_path, 0)
+        self._helper(image_path, 0, '1')
+
+    def test_08(self):
+        image_path = '../images/false/miss_orient/08.jpg'
+        self._helper(image_path, 2, '1')
 
 
 if __name__ == '__main__':

+ 1 - 1
testing/true_test.py

@@ -5,7 +5,7 @@ from pathlib import Path
 
 import requests
 
-url = 'http://192.168.199.208:18081'
+url = 'http://192.168.199.249:18081'
 
 
 def send_request(image_path, image_type):