Zhang Li 2 rokov pred
rodič
commit
65a4f20944
3 zmenil súbory, kde vykonal 253 pridanie a 208 odobranie
  1. 103 0
      core/line_parser.py
  2. 11 5
      core/ocr.py
  3. 139 203
      core/parser.py

+ 103 - 0
core/line_parser.py

@@ -0,0 +1,103 @@
+import numpy as np
+from dataclasses import dataclass
+
+# result 对象
+@dataclass
+class OcrResult(object):
+    box: np.ndarray
+    txt: str
+    conf: float
+
+    def __hash__(self):
+        return hash(repr(self))
+
+    def __repr__(self):
+        return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'
+
+    @property
+    def lt(self):
+        l, t = np.min(self.box, 0)
+        return [l, t]
+
+    @property
+    def rb(self):
+        r, b = np.max(self.box, 0)
+        return [r, b]
+
+    @property
+    def wh(self):
+        l, t = self.lt
+        r, b = self.rb
+        return [r - l, b - t]
+
+    def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
+        if is_horizontal:
+            return abs(self.lt[1] - b.lt[1]) < eps
+        else:
+            return abs(self.rb[0] - b.rb[0]) < eps
+
+
+# 行处理器
+class LineParser(object):
+    def __init__(self, ocr_raw_result):
+        self.ocr_res = []
+        for re in ocr_raw_result:
+            o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
+            self.ocr_res.append(o)
+        self.eps = self.avg_height * 0.66
+
+    @property
+    def is_horizontal(self):
+        res = self.ocr_res
+        wh = np.stack([np.abs(np.array(r.lt) - np.array(r.rb)) for r in res])
+        return np.sum(wh[:, 0] > wh[:, 1]) > np.sum(wh[:, 0] < wh[:, 1])
+
+    @property
+    def avg_height(self):
+        idx = self.is_horizontal + 0
+        return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
+
+    @property
+    def confidence(self):
+        return np.mean([r.conf for r in self.ocr_res])
+
+    # 处理器函数
+    def parse(self, eps=40.0):
+        # 存返回值
+        res = []
+
+        # 需要 处理的 OcrResult 对象  的长度
+        length = len(self.ocr_res)
+
+        # 如果字段数 小于等于1 就抛出异常
+        if length <= 1:
+            raise Exception('无法识别')
+
+        # 遍历数组 并处理他
+        for i in range(length):
+            # 拿出 OcrResult对象的 第i值 -暂存-
+            res_i = self.ocr_res[i]
+
+            # any:-> True
+            #       -input: 可迭代对象   |   -output: bool
+            #       -如果iterable的任何元素为true,则返回true。如果iterable为空,则返回false。 -与🚪-
+            # map: -> [False, False, False, False, True, True, False, False]
+            #       -input: (函数, 可迭代对象)     |    -output: 可迭代对象
+            #       -把 res 喂给lambda --lambda返回True的值-->  输出 新的可迭代对象
+
+            # 这次的 res_i 之前已经在结果集中,就继续下一个
+            if any(map(lambda x: res_i in x, res)): continue
+
+            # set() -> {}
+            # 初始化一个集合 即-输出-
+            res_row = set()
+
+            for j in range(i, length):
+                res_j = self.ocr_res[j]
+                if res_i.one_line(res_j, self.is_horizontal, self.eps):
+                    # LineParser 对象  不可以直接加入字典
+
+                    res_row.add(res_j)
+            res.append(res_row)
+        idx = self.is_horizontal + 0
+        return sorted([list(r) for r in res], key=lambda x: x[0].lt[idx])

+ 11 - 5
core/ocr.py

@@ -1,4 +1,6 @@
 from dataclasses import dataclass
+
+from core.line_parser import LineParser
 from core.parser import *
 from core.direction import *
 import numpy as np
@@ -15,13 +17,17 @@ class IdCardOcr:
         txts, confs, result = self._ocr(image)
         angle = self.angle_detector.detect_angle(image, result)
 
+        line_parser = LineParser(result)
+        line_result = line_parser.parse()
+        conf = line_parser.confidence
+
         if int(image_type) == 0:
-            parser = FrontParser(txts, confs)
+            parser = FrontParser(line_result)
         elif int(image_type) == 1:
-            parser = BackParser(txts, confs)
+            parser = BackParser(line_result)
         else:
             raise Exception('无法识别')
-        return self._post_process(angle, parser, image_type)
+        return self._post_process(conf, angle, parser, image_type)
 
     def _pre_process(self, image) -> (np.ndarray, int):
         angle = detect_angle(image)
@@ -50,9 +56,9 @@ class IdCardOcr:
         # print("......................................")
         return txts, confs, result
 
-    def _post_process(self, angle: int, parser: Parser, image_type: str):
+    def _post_process(self, conf, angle: int, parser: Parser, image_type: str):
         ocr_res = parser.parse()
-        conf = parser.confidence
+
 
         res = {
             "confidence": conf,

+ 139 - 203
core/parser.py

@@ -4,6 +4,9 @@ from dataclasses import dataclass
 from collections import defaultdict
 import numpy as np
 import cpca
+from typing import List
+
+from core.line_parser import OcrResult
 
 
 @dataclass
@@ -16,10 +19,8 @@ class RecItem:
 
 
 class Parser(object):
-    def __init__(self, txts, confs):
-        self.result = txts
-        self.confs = confs
-        assert len(self.result) == len(self.confs), 'result and confs do not match'
+    def __init__(self, ocr_results: List[OcrResult]):
+        self.result = ocr_results
         self.res = defaultdict(RecItem)
         self.keys = ["name", "id", "ethnicity", "gender", "birthday",
                      "address", "address_province", "address_city", "address_region", "address_detail", "expire_date"]
@@ -30,254 +31,189 @@ class Parser(object):
     def parse(self):
         return self.res
 
-    @property
-    def confidence(self):
-        return 0.
-
 
 class FrontParser(Parser):
     """
     """
 
-    def __init__(self, txts, confs):
-        Parser.__init__(self, txts, confs)
-        self.result = [
-            i.replace(" ", "").translate(str.maketrans("", "", string.punctuation))
-            for i in txts
-        ]
-        assert len(self.result) == len(self.confs), 'result and confs do not match'
+    def __init__(self, ocr_results: List[OcrResult]):
+        Parser.__init__(self, ocr_results)
 
     def birth(self):
         if len(self.res["id"].text) == 18:
             # 342423 2001  0  2    1  5    6552
             # 012345 6789  10 11   12 13   14
             str_num = self.res["id"].text
-            date = list(str_num[6:10] + "年" + str_num[10:12] + "月" + str_num[12:14] + "日")
-            if date[date.index("月") - 2] == "0":
-                del date[date.index("月") - 2]
-            if date[date.index("日") - 2] == "0":
-                del date[date.index("日") - 2]
-            self.res["birthday"].text = "".join(date)
+            date = str_num[6:10] + "年" + str_num[10:12] + "月" + str_num[12:14] + "日"
+            self.res["birthday"] = RecItem(date, self.res['id'].confidence)
 
     def card_no(self):
         """
         身份证号码
         """
-        for i in range(len(self.result)):
-            txt = self.result[i]
-
-            # 身份证号码
-            if "X" in txt or "x" in txt:
-                res = re.findall("\d*[X|x]", txt)
-            else:
-                res = re.findall("\d{16,18}", txt)
+        for idx, row in enumerate(self.result):
+            for r in row:
+                txt = r.txt
 
-            if len(res) > 0:
-                if len(res[0]) == 18:
-                    self.res["id"].text = res[0].replace("号码", "")
-                    self.res["id"].confidence = self.confs[i]
-                    self.res["gender"].text = "男" if int(res[0][16]) % 2 else "女"
-                    self.res["gender"].confidence = self.confs[i]
-                break
-
-    def full_name(self):
-        """
-        身份证姓名
-        """
-        for i in range(len(self.result)):
-            txt = self.result[i]
-            length = len(txt)
-            if "姓名" in txt:
-                if len(txt) < 7:
-                    res = re.findall("姓名[\u4e00-\u9fa5]{1,4}", txt)
-                    # 三个字名字
-                    if len(res) > 0:
-                        self.res["name"].text = res[0].split("姓名")[-1]
-                        self.res["name"].confidence = self.confs[i]
-                        self.result[i] = "temp"  # 避免身份证姓名对地址造成干扰
-                        break
+                # 身份证号码
+                if "X" in txt or "x" in txt:
+                    res = re.findall("\d*[X|x]", txt)
                 else:
-                    res = txt[2:]
-                    name_list = []
-                    point_unicode = ["\u2E31", "\u2218", "\u2219", "\u22C5", "\u25E6", "\u2981",
-                                     "\u00B7", "\u0387", "\u05BC", "\u16EB", "\u2022", "\u2027",
-                                     "\u2E30", "\uFF0E", "\u30FB", "\uFF65", "\u10101"]
-                    for n in range(len(point_unicode)):
-                        point = re.findall(point_unicode[n], res)
-                        if len(point) != 0:
-                            name_list = res.split(point[0])
-                            for m in range(len(name_list)):
-                                name_list[m] = name_list[m].replace(' ', '')
-                            res = name_list[0] + '\u00B7' + name_list[1]
-
-                self.res["name"].text = res
-                self.res["name"].confidence = self.confs[i]
-                self.result[i] = "temp"  # 避免身份证姓名对地址造成干扰
-
-    def gender(self):
+                    res = re.findall("\d{16,18}", txt)
+
+                if len(res) > 0:
+                    if len(res[0]) == 18:
+                        self.res["id"].text = res[0]
+                        self.res["id"].confidence = r.conf
+                        self.res["gender"].text = "男" if int(res[0][16]) % 2 else "女"
+                        self.res["gender"].confidence = r.conf
+                        if idx < 2:
+                            self.result = self.result[idx + 1:]
+                            self.result.reverse()
+                        else:
+                            self.result = self.result[:idx]
+                    return
+        raise Exception('无法识别')
+
+    def name(self):
         """
-        性别女民族汉
+        姓名
         """
-        if len(self.res["gender"].text) != 0: return
-        for i in range(len(self.result)):
-            txt = self.result[i]
-            if "男" in txt:
-                self.res["gender"] = RecItem("男", self.confs[i])
-                break
 
-            if "女" in txt:
-                self.res["gender"] = RecItem("女", self.confs[i])
-                break
+        if len(self.result[0]) == 2:
+            for r in self.result[0]:
+                if '姓' in r.txt or ('名' in r.txt and len(r.txt) < 3):
+                    continue
+                else:
+                    self.res['name'] = RecItem(r.txt, r.conf)
+                    return
+        if len(self.result[0]) == 1:
+            txt = self.result[0][0].txt
+            conf = self.result[0][0].conf
+            if "姓名" in txt:
+
+                res = txt[2:]
+                name_list = []
+                point_unicode = ["\u2E31", "\u2218", "\u2219", "\u22C5", "\u25E6", "\u2981",
+                                 "\u00B7", "\u0387", "\u05BC", "\u16EB", "\u2022", "\u2027",
+                                 "\u2E30", "\uFF0E", "\u30FB", "\uFF65", "\u10101"]
+                for n in range(len(point_unicode)):
+                    point = re.findall(point_unicode[n], res)
+                    if len(point) != 0:
+                        name_list = res.split(point[0])
+                        self.res['name'] = RecItem(name_list[0].replace('姓名') + '\u00B7' + name_list[1], conf)
+                        return
+
+                res = re.findall("姓名[\u4e00-\u9fa5]{1,7}", txt)
+                if len(res) > 0:
+                    self.res["name"] = RecItem(res[0].split("姓名")[-1], conf)
+                    return
+        raise Exception('无法识别')
 
     def national(self):
         # 性别女民族汉
-        for i in range(len(self.result)):
-            txt = self.result[i]
+        if len(self.result[1]) == 1:
+            txt = self.result[1][0].txt
+            conf = self.result[1][0].conf
             res = re.findall(".*民族[\u4e00-\u9fa5]+", txt)
 
             if len(res) > 0:
-                self.res["ethnicity"] = RecItem(res[0].split("族")[-1], self.confs[i])
-                break
+                self.res["ethnicity"] = RecItem(res[0].split("族")[-1], conf)
+                return
 
     def address(self):
         """
         身份证地址
         """
-        addString = []
-        conf = []
-        for i in range(len(self.result)):
-            txt = self.result[i]
-            txt = txt.replace("号码", "")
-            if "公民" in txt:
-                txt = "temp"
-            # 身份证地址
-
-            if (
-                    "住址" in txt
-                    or "址" in txt
-                    or "省" in txt
-                    or "市" in txt
-                    or "县" in txt
-                    or "街" in txt
-                    or "乡" in txt
-                    or "村" in txt
-                    or "镇" in txt
-                    or "区" in txt
-                    or "城" in txt
-                    or "组" in txt
-                    or "旗" in txt
-                    or "号" in txt
-            ):
-                # if "住址" in txt or "省" in txt or "址" in txt:
-                if "住址" in txt or "省" in txt or "址" in txt or \
-                        ('市' in txt and len(addString) > 0 and '市' not in addString[0]):
-                    addString.insert(0, txt.split("址")[-1])
-                else:
-                    addString.append(txt)
-                conf.append(self.confs[i])
-                self.result[i] = "temp"
-        if len(addString) > 0:
-            self.res["address"].text = "".join(addString)
-            self.res["address"].confidence = np.mean(conf)
-        # print(f'addr: {self.res["Address"]}')
+        res = []
+        confs = []
 
-    def split_addr(self):
-        if self.res["address"].text:
-            conf = self.res["address"].confidence
-            df = cpca.transform([self.res["address"].text])
-            # print(df)
-
-            province = df.iloc[0, 0]
-            city = df.iloc[0, 1]
-            region = df.iloc[0, 2]
-            detail = df.iloc[0, 3]
-            print(f'pronvince: {province}, city: {city}, region: {region}, detail: {detail}')
-            self.res["address_province"] = RecItem(province, conf)
-            self.res["address_city"] = RecItem(city, conf)
-            if detail and "旗" in detail:
-                temp_region = []
-                temp_region.insert(0, detail.split("旗")[0] + "旗")
-                self.res["address_region"] = RecItem(temp_region[0], conf)
-                self.res["address_detail"] = RecItem(detail.split("旗")[-1], conf)
-            else:
-                self.res["address_region"] = RecItem(region, conf)
-                self.res["address_detail"] = RecItem(detail, conf)
+        for row in self.result[3:]:
+            for r in row:
+                txt = r.txt
 
-    def expire_date(self):
-        for txt, conf in zip(self.result, self.confs):
-            txt = txt.replace('.', '')
-            res = re.findall('\d{8}\-\d{8}', txt)
-            if res:
-                self.res["expire_date"] = RecItem(res[0], conf)
-                break
-            res = re.findall('\d{8}\-长期', txt)
-            if res:
-                self.res["expire_date"] = RecItem(res[0], conf)
-                break
-
-    def predict_name(self):
-        """
-        如果PaddleOCR返回的不是姓名xx连着的,则需要去猜测这个姓名,此处需要改进
-        """
-        if len(self.res['name'].text) > 1: return
-        for i in range(len(self.result)):
-            txt = self.result[i]
-            if 1 < len(txt) < 5:
                 if (
-                        "性别" not in txt
-                        and "姓名" not in txt
-                        and "民族" not in txt
-                        and "住址" not in txt
-                        and "出生" not in txt
-                        and "号码" not in txt
-                        and "身份" not in txt
+                        "住址" in txt
+                        or "址" in txt
+                        or "省" in txt
+                        or "市" in txt
+                        or "县" in txt
+                        or "街" in txt
+                        or "乡" in txt
+                        or "村" in txt
+                        or "镇" in txt
+                        or "区" in txt
+                        or "城" in txt
+                        or "组" in txt
+                        or "旗" in txt
+                        or "号" in txt
                 ):
-                    result = re.findall("[\u4e00-\u9fa5]{2,4}", txt)
-                    if len(result) > 0:
-                        self.res["Name"] = RecItem(result[0], self.confs[i])
-                        break
+                    # if "住址" in txt or "省" in txt or "址" in txt:
+                    if "住址" in txt or "址" in txt:
+                        res.append(txt.split("址")[-1])
+                    else:
+                        res.append(txt)
+                    confs.append(r.conf)
+
+        if len(res) > 0:
+            self.res["address"] = RecItem("".join(res), np.mean(confs))
+            self.split_addr()
+            return
+        raise Exception('无法识别')
 
-    @property
-    def confidence(self):
-        return np.mean(self.confs)
+    def split_addr(self):
+        conf = self.res["address"].confidence
+        df = cpca.transform([self.res["address"].text])
+        # print(df)
+
+        province = df.iloc[0, 0]
+        city = df.iloc[0, 1]
+        region = df.iloc[0, 2]
+        detail = df.iloc[0, 3]
+        print(f'pronvince: {province}, city: {city}, region: {region}, detail: {detail}')
+        self.res["address_province"] = RecItem(province, conf)
+        self.res["address_city"] = RecItem(city, conf)
+        if detail and "旗" in detail:
+            temp_region = []
+            temp_region.insert(0, detail.split("旗")[0] + "旗")
+            self.res["address_region"] = RecItem(temp_region[0], conf)
+            self.res["address_detail"] = RecItem(detail.split("旗")[-1], conf)
+        else:
+            self.res["address_region"] = RecItem(region, conf)
+            self.res["address_detail"] = RecItem(detail, conf)
+        if not self.res['address_region'].text or not self.res['address_detail'].text:
+            raise Exception('无法识别')
 
     def parse(self):
-        self.full_name()
-        self.national()
         self.card_no()
-        self.address()
-        self.split_addr()
+        self.name()
+        self.national()
         self.birth()
-        self.gender()
-        self.expire_date()
-        self.predict_name()
-        if not self.res["id"].text:
-            raise Exception("没有识别到身份证号")
+        self.address()
         return {key: self.res[key].to_dict() for key in self.keys}
 
 
 class BackParser(Parser):
-    def __init__(self, txts, confs):
-        Parser.__init__(self, txts, confs)
+    def __init__(self, ocr_results: List[OcrResult]):
+        Parser.__init__(self, ocr_results)
 
     def expire_date(self):
-        for txt, conf in zip(self.result, self.confs):
-            txt = txt.replace('.', '')
-            res = re.findall('\d{8}\-\d{8}', txt)
-            if res:
-                self.res["expire_date"] = RecItem(res[0], conf)
-                break
-            res = re.findall('\d{8}\-长期', txt)
-            if res:
-                self.res["expire_date"] = RecItem(res[0], conf)
-                break
-
-    @property
-    def confidence(self):
-        return np.mean(self.confs)
+        for row in self.result:
+            for r in row:
+                txt = r.txt
+                txt = txt.replace('.', '')
+                res = re.findall('\d{8}\-\d{8}', txt)
+                if res:
+                    self.res["expire_date"] = RecItem(res[0], r.conf)
+                    return
+                res = re.findall('\d{8}\-长期', txt)
+                if res:
+                    self.res["expire_date"] = RecItem(res[0], r.conf)
+                    return
+        raise Exception('无法识别')
 
     def parse(self):
         self.expire_date()
         if not self.res["expire_date"].text:
             raise Exception("无法识别")
-        return {key: self.res[key].to_dict() for key in self.keys}
+        return {key: self.res[key].to_dict() for key in self.keys}