Browse Source

update detect model & rec model

Raychar 2 years ago
parent
commit
2c36bfc8d8
4 changed files with 140 additions and 112 deletions
  1. 12 4
      core/direction.py
  2. 53 10
      core/line_parser.py
  3. 75 24
      core/ocr.py
  4. 0 74
      testing/true_test.py

+ 12 - 4
core/direction.py

@@ -3,6 +3,8 @@ from dataclasses import dataclass
 from paddleocr import PaddleOCR
 
 import numpy as np
+import imutils
+import matplotlib.pyplot as plt
 
 
 def detect_angle(result, ocr_anchor: OcrAnchor):
@@ -12,6 +14,7 @@ def detect_angle(result, ocr_anchor: OcrAnchor):
     print(res)
     print('------ angle ocr -------')
     is_horizontal = lp.is_horizontal
+    # rotate_angle = lp.is_need_rotate
     return ocr_anchor.locate_anchor(res, is_horizontal)
 
 
@@ -33,14 +36,19 @@ class AngleDetector(object):
         except Exception as e:
             print("direction.py这里有异常。。。。。。")
             print(e)
-            # 如果第一次识别不到,旋转90度再识别
-            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
-            result = self.ocr.ocr(img, cls=True)
+            # 如果第一次识别不到,再识别
+            result = self.ocr.ocr(img, cls=False)
             angle = detect_angle(result, ocr_anchor)
             # 旋转90度之后要重新计算角度
-            return (angle - 1 + 4) % 4, result
+            # return (angle - 1 + 4) % 4, result
+            return angle, result
 
     def origin_detect(self, img):
         # 这边一般是在自己的检测模型result=[]时,再使用官方的模型做个检测,如果这个也没有结果,那就真的检测不出来
         result = self.ocr.ocr(img)
         return result
+
+    def det_oneline(self, result):
+        # 这边已经是转正之后的图片,不需要考虑是否水平,只要检测是否一行
+        lp = LineParser(result)
+        return lp.detection_parse()

+ 53 - 10
core/line_parser.py

@@ -1,5 +1,6 @@
-from dataclasses import dataclass
+import math
 import numpy as np
+from dataclasses import dataclass
 
 
 # result 对象
@@ -31,6 +32,24 @@ class OcrResult(object):
         r, b = self.rb
         return [r - l, b - t]
 
+    @property
+    def area(self):
+        w, h = self.wh
+        return w * h
+
+    @property
+    def is_slope(self):
+        """
+        function: 10~60,-60~-10度之间,需要旋转图片,因为目前的检测模型对于倾斜角度的不能检测
+        return: 需要旋转的角度 ---> tan
+        """
+        p0 = self.box[0]
+        p1 = self.box[1]
+        if p0[0] == p1[0]:  # 如果是正常的那就不用转
+            return 0
+        slope = 1. * (p1[1] - p0[1]) / (p1[0] - p0[0])
+        return slope
+
     @property
     def center(self):
         l, t = self.lt
@@ -42,19 +61,31 @@ class OcrResult(object):
         x_idx = 1 - y_idx
         if b.lt[x_idx] < self.lt[x_idx] < self.rb[x_idx] < b.rb[x_idx]: return False
         if self.lt[x_idx] < b.lt[x_idx] < b.rb[x_idx] < self.rb[x_idx]: return False
-        eps = 0.5 * (self.wh[y_idx] + b.wh[y_idx])
+        eps = 0.25 * (self.wh[y_idx] + b.wh[y_idx])
         dist = abs(self.center[y_idx] - b.center[y_idx])
         return dist < eps
 
 
 # 行处理器
 class LineParser(object):
-    def __init__(self, ocr_raw_result):
+    def __init__(self, ocr_raw_result, filters=None):
+        # self.rotate_angle = 0
+        if filters is None:
+            filters = [lambda x: x.is_slope]
         self.ocr_res = []
         for re in ocr_raw_result:
             o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
+            # if any([f(o) for f in filters]): continue
             self.ocr_res.append(o)
-        self.eps = self.avg_height * 0.66
+        self.ocr_res = sorted(self.ocr_res, key=lambda x: x.area, reverse=True)
+
+        # 找到最大的检测框,大概率就是卡号所在位置
+        # max_res = self.ocr_res[0]
+        # for f in filters:
+        #     k = f(max_res)
+        #     self.rotate_angle = math.atan(k) * 180 / math.pi
+
+        self.eps = self.avg_height * 0.7
 
     @property
     def is_horizontal(self):
@@ -62,11 +93,16 @@ class LineParser(object):
         wh = np.stack([np.abs(np.array(r.lt) - np.array(r.rb)) for r in res])
         return np.sum(wh[:, 0] > wh[:, 1]) > np.sum(wh[:, 0] < wh[:, 1])
 
+    # @property
+    # def is_need_rotate(self):
+    #     return self.rotate_angle
+
     @property
     def avg_height(self):
         idx = self.is_horizontal + 0
         return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
 
+    # 整体置信度
     @property
     def confidence(self):
         return np.mean([r.conf for r in self.ocr_res])
@@ -79,15 +115,12 @@ class LineParser(object):
         # 需要 处理的 OcrResult 对象  的长度
         length = len(self.ocr_res)
 
-        # 如果字段数 小于等于1 就抛出异常
-        # if length <= 1:
-        #     raise Exception('无法识别')
-
         # 遍历数组 并处理他
         for i in range(length):
             # 拿出 OcrResult对象的 第i值 -暂存-
             res_i = self.ocr_res[i]
 
+            # 这次的 res_i 之前已经在结果集中,就继续下一个
             if any(map(lambda x: res_i in x, res)): continue
 
             # set() -> {}
@@ -96,10 +129,20 @@ class LineParser(object):
 
             for j in range(i, length):
                 res_j = self.ocr_res[j]
+                # 这次的 res_i 之前已经在结果集中,就继续下一个
+                if any(map(lambda x: res_j in x, res)): continue
+
                 if res_i.one_line(res_j, self.is_horizontal, self.eps):
-                    if any(map(lambda x: res_j in x, res)): continue
                     # LineParser 对象  不可以直接加入字典
+
                     res_row.add(res_j)
             res.append(res_row)
         idx = self.is_horizontal + 0
-        return sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in res], key=lambda x: x[0].lt[idx])
+        res = sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in res], key=lambda x: x[0].lt[idx])
+
+        return res
+
+    def detection_parse(self, eps=40.0):
+        result = self.ocr_res
+        if len(result) == 2:
+            return result[0].one_line(result[1], True, self.eps)

+ 75 - 24
core/ocr.py

@@ -1,6 +1,8 @@
 from dataclasses import dataclass
 
+import cv2
 import numpy as np
+import math
 from paddleocr import PaddleOCR, draw_ocr
 
 from core.direction import *
@@ -20,7 +22,6 @@ class BankOcr:
         print(f'---------- detect angle: {angle} 角度 --------')
         # 这里使用自己训练的检测识别模型,在此之前,理想情况下,所有的银行卡的角度都已经是0,(正向)
         _, _, result = self._ocr(image)
-
         # self.imshow(image, result)  # 将检测图片保存
         return self._post_process(result, angle)
 
@@ -38,13 +39,28 @@ class BankOcr:
 
         if angle == 1:
             image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
-        # print("检测出来的角度:", angle)  # 逆时针
         if angle == 2:
             image = cv2.rotate(image, cv2.ROTATE_180)
         if angle == 3:
             image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
 
-        return image, angle, result
+        # if -60 <= rotate_angle <= -20 or 20 <= rotate_angle <= 60:
+        #     print("需要旋转角度")
+        #     image = imutils.rotate(image, rotate_angle)
+
+        # 因为有些img像素过大,导致检测框效果不好,识别就会出问题
+        h, w, _ = image.shape
+        h_ratio = 1 if h <= 1000 else h / 1000
+        w_ratio = 1 if w <= 1000 else w / 1000
+
+        if h_ratio == 1 and w_ratio == 1:
+            return image, angle, result
+        elif h_ratio != 1 or w_ratio != 1:
+            ratio = h_ratio if h_ratio > w_ratio else w_ratio
+            image = cv2.resize(image, (w // math.ceil(ratio), h // math.ceil(ratio)))
+            print(image.shape)
+
+            return image, angle, result
 
     def _ocr(self, image):
         # 获取模型检测结果,因为是正的照片了,所以不需要方向分类器
@@ -69,29 +85,64 @@ class BankOcr:
         if result:
             confs = [line[1][1] for line in result]
             print("自己的检测模型得到的conf:", confs)
-            if len(result) == 2 and all(map(lambda x: x > 0.975, confs)):
-                l_box, r_box = [], []
-                l_box.extend(result[0][0])
-                r_box.extend(result[1][0])
-
-                l_max, _ = np.max(l_box, 0)
-                r_min, _ = np.min(r_box, 0)
 
-                if l_max > r_min:
-                    print("说明自己的检测模型不好")
-                    result = self.angle_detector.origin_detect(image)
-            else:
-                # 一般情况下,len=1
-                flag = 0
-                if map(lambda x: x >= 0.975, confs):
-                    flag = 1
-                # for conf in confs:
-                #     if conf >= 0.975:
-                #         flag = 1
-                #         break
-                if flag == 0:
-                    print("需要再次进行官方的检测代码。。。。。。。。。。。。")
+            # 根绝len(result)分规则判断
+            if len(result) == 1:
+                if confs[0] > 0.987:
+                    txts = [line[1][0] for line in result]
+                    return txts, confs, result
+                else:
+                    print("len(result)=1时,再次用官方代码检测。。。。。。")
                     result = self.angle_detector.origin_detect(image)
+            elif len(result) == 2:
+                # 1.判断两个检测框在不在一行
+                is_oneline = self.angle_detector.det_oneline(result)
+                # 2.如果不在一行
+                if not is_oneline:
+                    txts = [line[1][0] for line in result]
+                    if not (any(map(lambda x: x > 0.987, confs)) and len(re.findall('\d{16,20}', txts)) > 0):
+                        print("len(result)=2,但是不在一行。。。。。。")
+                        result = self.angle_detector.origin_detect(image)
+                # 3. 如果在一行
+                elif is_oneline:
+                    if all(map(lambda x: x > 0.987, confs)):
+                        l_box, r_box = [], []
+                        l_box.extend(result[0][0])
+                        r_box.extend(result[1][0])
+
+                        l_max, _ = np.max(l_box, 0)
+                        r_min, _ = np.min(r_box, 0)
+
+                        if l_max > r_min:
+                            print("len(result)=2,在一行,但有重叠。。。。。。")
+                            result = self.angle_detector.origin_detect(image)
+                    else:
+                        print("len(result)=2,在一行,但有一个检测不行。。。。。。")
+                        result = self.angle_detector.origin_detect(image)
+            elif len(result) > 2:
+                print("len(result)=3,直接换官方检测。。。。。。")
+                result = self.angle_detector.origin_detect(image)
+
+            # elif len(result) == 2 and all(map(lambda x: x > 0.975, confs)):
+            #     l_box, r_box = [], []
+            #     l_box.extend(result[0][0])
+            #     r_box.extend(result[1][0])
+            #
+            #     l_max, _ = np.max(l_box, 0)
+            #     r_min, _ = np.min(r_box, 0)
+            #
+            #     if l_max > r_min:
+            #         print("说明自己的检测模型不好")
+            #         result = self.angle_detector.origin_detect(image)
+            # else:
+            #     # 一般情况下,len=1
+            #     flag = 0
+            #     if all(map(lambda x: x >= 0.975, confs)):
+            #         flag = 1
+            #
+            #     if flag == 0:
+            #         print("需要再次进行官方的检测代码。。。。。。。。。。。。")
+            #         result = self.angle_detector.origin_detect(image)
 
         # 如果还是空,那就检测不出来
         if not result:

+ 0 - 74
testing/true_test.py

@@ -1,74 +0,0 @@
-import unittest
-from pathlib import Path
-
-
-from testing.utils import send_request
-
-
-class TestBankCardOcr(unittest.TestCase):
-
-    def _helper(self, image_path, sta, orient, card_no):
-        root = Path(__file__).parent
-        image_path = str(root / image_path)
-        r = send_request(image_path)
-        self.assertEqual(sta, r['status'], f'{image_path} status case error')
-        self.assertEqual(orient, r['result']['orientation'], f'{image_path} orientation case error')
-        self.assertEqual(card_no, r['result']['number']['text'], f'{image_path} number case error')
-
-    def test_true_t01(self):
-        image_path = '../images/ture/t01.png'
-        self._helper(image_path, '000', 1, '6217002580007133039')
-
-    def test_true_t02(self):
-        image_path = '../images/ture/t02.png'
-        self._helper(image_path, '000', 0, '6217000410005833061')
-
-    def test_true_t03(self):
-        image_path = '../images/ture/t03.png'
-        self._helper(image_path, '000', 0, '6217000940023315733')
-
-    def test_true_t04(self):
-        image_path = '../images/ture/t04.png'
-        self._helper(image_path, '000', 1, '6214835665420657')
-
-    def test_true_t05(self):
-        image_path = '../images/ture/t05.png'
-        self._helper(image_path, '000', 0, '6217000780063553227')
-
-    # 以前正确,现在方向错误
-    def test_true_t06(self):
-        image_path = '../images/ture/t06.png'
-        self._helper(image_path, '000', 0, '6230580000168512874')
-
-    def test_true_t07(self):
-        image_path = '../images/ture/t07.png'
-        self._helper(image_path, '000', 0, '6216618401001365345')
-
-    def test_true_t08(self):
-        image_path = '../images/ture/t08.jpg'
-        self._helper(image_path, '000', 0, '6217000416005548153')
-
-    def test_true_t09(self):
-        image_path = '../images/ture/t09.jpg'
-        self._helper(image_path, '000', 0, '6217000416005548146')
-
-    def test_true_t10(self):
-        image_path = '../images/ture/t10.jpg'
-        self._helper(image_path, '000', 0, '6217000416000737652')
-
-    def test_true_t11(self):
-        image_path = '../images/ture/t11.png'
-        self._helper(image_path, '000', 0, '6214834723639358')
-
-    def test_true_t12(self):
-        image_path = '../images/ture/t12.jpg'
-        self._helper(image_path, '000', 0, '6228480878204631674')
-
-    def test_true_05(self):
-        image_path = '../images/ture/05.jpg'
-        self._helper(image_path, '000', 0, '6217000416004473577')
-
-    def test_true_08(self):
-        image_path = '../images/ture/08.png'
-        self._helper(image_path, '000', 0, '6216618401001365345')
-