123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406 |
- import copy
- import math
- import pickle
- import re
- from dataclasses import dataclass
- from typing import List
- import numpy as np
@dataclass
class OcrResult(object):
    """One OCR detection: a text box, its recognised text and its confidence.

    Attributes:
        box: array of box corner points, one ``(x, y)`` row per point.
        txt: recognised text.
        conf: recognition confidence.
    """
    box: np.ndarray
    txt: str
    conf: float

    def __eq__(self, other):
        """Compare field by field, using ``np.array_equal`` for the box.

        The dataclass-generated ``__eq__`` compares the box arrays with ``==``,
        which yields an element-wise array whose truth value is ambiguous, so
        comparing two results with distinct boxes raised ``ValueError``.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        return bool(self.txt == other.txt
                    and self.conf == other.conf
                    and np.array_equal(self.box, other.box))

    def __hash__(self):
        # Hash exactly what __eq__ compares so that equal results hash equally,
        # even when their boxes differ only in dtype (e.g. int vs float).
        return hash((self.txt, self.conf, tuple(float(v) for v in np.ravel(self.box))))

    def __repr__(self):
        return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'

    @property
    def ltrb(self):
        """Axis-aligned bounds ``[left, top, right, bottom]`` of the box."""
        l, t = np.min(self.box, 0)
        r, b = np.max(self.box, 0)
        return [l, t, r, b]

    def lt(self):
        """Top-left corner ``[left, top]`` (a plain method, not a property)."""
        l, t = np.min(self.box, 0)
        return [l, t]

    @property
    def wh(self):
        """Width and height ``[w, h]`` of the bounding box."""
        l, t, r, b = self.ltrb
        return [r - l, b - t]

    @property
    def center(self):
        """Centre ``[x, y]`` of the bounding box."""
        l, t, r, b = self.ltrb
        return [(r + l) / 2, (b + t) / 2]
class RanksParser(object):
    """Turn the OCR output of a household-register (户口本) page into fields.

    Six fields are targeted: 姓名 (name), 出生地 (birthplace), 籍贯 (native
    place), 身份证号 (ID number, listed together with 性别/sex), 血型 (blood
    type) and 信仰 (religion).

    The layout is resolved relative to ``self.anchor``. Subclasses override
    that property; the base implementation returns ``None``.
    """

    def __init__(self, res: List):
        """Wrap the raw OCR items and drop the boxes on the page-title line.

        :param res: raw OCR items, each indexable as ``(box, (txt, conf))``.
        """
        self.col_fields = None
        self.row_fields = None
        self.ocr_res = [OcrResult(np.array(r[0]), r[1][0], r[1][1]) for r in res]

        # Fragments of the title 常住人口登记卡 (plus the variant 常驻).
        title_keys = ('常驻', '常住', '人口', '口登', '记卡')
        drop = set()
        for title in self.ocr_res:
            if not any(key in title.txt for key in title_keys):
                continue
            # Boxes whose vertical centre lies in the title band (padded by 30%
            # of the mean box height) are title noise, as is a lone '美' box.
            margin = self.mean_h_esp * 0.3
            band_top = title.ltrb[1] - margin
            band_bottom = title.ltrb[3] + margin
            for k, r in enumerate(self.ocr_res):
                if r.txt == '美' or band_top < r.center[1] < band_bottom:
                    drop.add(k)
            # Only the first title match is used.
            break
        # Filtering by a set of indices cannot remove a wrong box when the same
        # index is flagged twice.
        self.ocr_res = [r for k, r in enumerate(self.ocr_res) if k not in drop]

    @property
    def confidence(self):
        """Mean OCR confidence of the kept boxes."""
        return np.mean([r.conf for r in self.ocr_res])

    @property
    def center_five_row(self):
        """Vertical band ``[top, bottom]`` that excludes the top and bottom
        fifth of the span covered by all boxes.

        :raises ValueError: when there are no boxes.
        """
        points = np.concatenate([r.box for r in self.ocr_res])
        _, top = np.min(points, 0)
        _, bottom = np.max(points, 0)
        fifth = (bottom - top) / 5
        return [top + fifth, bottom - fifth]

    @property
    def mean_h_esp(self):
        """Mean box height, used as the unit for the layout tolerances."""
        return np.mean([r.wh[1] for r in self.ocr_res])

    @property
    def anchor(self) -> OcrResult:
        """Anchor box (the ID number) that the layout is measured from.

        Subclasses override this; the base implementation returns ``None``.
        """
        return None

    @staticmethod
    def merger_or(or_a, or_b):
        """Merge two boxes into one that covers both.

        The result holds the axis-aligned bounding box of both boxes (as four
        corners), the text ``or_a.txt + or_b.txt`` and ``or_a``'s confidence.
        Neither input is modified.
        """
        points = np.vstack([or_a.box, or_b.box])
        l, t = np.min(points, 0)
        r, b = np.max(points, 0)
        return OcrResult(np.array([[l, t], [r, t], [r, b], [l, b]]), or_a.txt + or_b.txt, or_a.conf)

    def get_xy(self, r: OcrResult):
        """Top-left corner ``[left, top]`` of ``r``'s bounding box."""
        return r.ltrb[:2]

    def all_required_fields(self, eps):
        """Collect the fields of the resident page.

        :param eps: vertical tolerance (pixels) for boxes on the anchor's row.
        :return: ``[row_fields, col_fields]``. ``row_fields`` is the anchor plus
            the boxes chained to its right on the same line, sorted by top then
            left. ``col_fields`` is the column above the anchor (姓名, 出生地,
            籍贯) normalised by ``merge_address`` and followed by the ID-number
            entry; the blood-type and religion boxes are appended when found.
        """
        anchor = self.anchor
        esp = self.mean_h_esp
        latin = re.compile(r'[A-Za-z]')

        def blood_text(txt):
            # '血型<letters>型' when txt holds '型' and Latin letters, else None.
            letters = ''.join(latin.findall(txt))
            if '型' not in txt or not letters:
                return None
            # Keep the two-letter AB type instead of truncating it to A.
            kind = letters[:2] if letters[:2].upper() == 'AB' else letters[0]
            return '血型' + kind + '型'

        def get_one_line(anc, field):
            # Box labelled `field` to the right of `anc`, merged with the boxes
            # on `anc`'s line to the right of the label; None when no label is
            # found. A box that already reads as a blood type (or contains '不',
            # i.e. 不明/unknown) is normalised in place and returned directly.
            _, anc_t, anc_r, _ = anc.ltrb
            label = None
            for r in self.ocr_res:
                blood = blood_text(r.txt)
                if blood is not None:
                    r.txt = blood
                    return r
                if '不' in r.txt and '不便' not in r.txt:
                    r.txt = '血型不明'
                    return r
                if label is None and field in r.txt and r.ltrb[0] - anc_r > 0:
                    label = r
            if label is None:
                return None
            label_r = label.ltrb[2]
            merged = label
            for j in self.ocr_res:
                j_l, j_t = j.ltrb[:2]
                same_line = abs(j_t - anc_t) < esp * 0.8
                right_of_label = j_l - label_r > 0 or abs(j_l - label_r) < 15
                if same_line and right_of_label and j.txt not in merged.txt:
                    merged = self.merger_or(merged, j)
            # Copy before aligning the top edge with the anchor so the OCR box
            # shared with self.ocr_res is left untouched.
            box = merged.box.copy()
            box[0][1] = anc_t
            box[1][1] = anc_t
            return OcrResult(box, blood_text(merged.txt) or merged.txt, merged.conf)

        def get_row_up(anc):
            # Religion box above `anc` (the blood-type box), prefixed with
            # '宗教信仰是'. When `anc` is None only the keyword checks apply and
            # None is returned if none of them matches.
            prefix = '宗教信仰是'
            if anc is not None:
                anc_l, anc_t = anc.ltrb[:2]
                anc_cy = anc.center[1]
            candidates = []
            for r in self.ocr_res:
                if "天主教" in r.txt:
                    r.txt = '宗教信仰是天主教'
                    return r
                if '无宗教' in r.txt or '无亲教' in r.txt:
                    r.txt = '宗教信仰是无宗教信仰'
                    return r
                if anc is None:
                    # The remaining checks are relative to `anc`.
                    continue
                if "无" in r.txt and len(r.txt) < 3 and anc_cy - r.center[1] > 0:
                    r.txt = '宗教信仰是无'
                    return r
                # Left-aligned with `anc` and above it.
                if abs(r.ltrb[0] - anc_l) < esp * 0.8 and r.ltrb[1] - anc_t < 0:
                    candidates.append(r)
            if anc is None:
                return None
            for r in candidates:
                if any(ch in r.txt for ch in '宗教信仰'):
                    # Idempotent: a repeated parse must not stack the prefix.
                    if not r.txt.startswith(prefix):
                        r.txt = prefix + r.txt
                    return r
            return OcrResult(anc.box, prefix, anc.conf)

        # Row: boxes chained rightwards from the anchor whose top stays within
        # `eps` of the previously accepted box.
        ref_x, ref_y = self.get_xy(anchor)
        row_fields = [anchor]
        for r in self.ocr_res:
            x, y = self.get_xy(r)
            if abs(y - ref_y) < eps and x - ref_x > 0:
                row_fields.append(r)
                ref_x, ref_y = x, y
        row_fields = sorted(row_fields, key=lambda f: [f.ltrb[1], f.ltrb[0]])

        # Column: boxes 3 to 12.6 mean heights above the anchor that overlap its
        # horizontal span (姓名, 出生地, 籍贯), minus labels and title noise.
        a_l, _, a_r, _ = anchor.ltrb
        a_cy = anchor.center[1]
        col_fields = [anchor]
        for c in self.ocr_res:
            txt = c.txt
            if any(k in txt for k in ('常', '住', '人', '口')):
                continue
            if any(k in txt for k in ('农业家', '户', '姓名', '出生地', '居民')):
                continue
            if any(k in txt for k in ('户', '性', '民', '出')) and abs(c.center[0] - a_r) < 20:
                continue
            # A right edge inside the span also counts, so a value fused with
            # its label ("名xxx") that starts left of the column is kept.
            in_column = ((a_l < c.center[0] < a_r or a_l < c.ltrb[2] < a_r)
                         and esp * 12.6 > a_cy - c.center[1] > esp * 3)
            if not in_column:
                continue
            if len(txt) < 5 and txt.startswith('名'):
                c.txt = txt.split('名')[-1]
            # Appended once, so a fused name cannot appear twice.
            col_fields.append(c)
        col_fields = sorted(col_fields, key=lambda f: [f.ltrb[1], f.ltrb[0]])

        # A short entry right below the topmost one is taken as the former
        # name (曾用名) and dropped.
        if len(col_fields) > 1 and len(col_fields[1].txt) < 5:
            del col_fields[1]
        col_fields = self.merge_address(col_fields)

        # Each lookup runs once: the helpers normalise OCR text in place.
        blood = get_one_line(anchor, '血')
        if blood is not None:
            col_fields.append(blood)
        religion = get_row_up(blood)
        if religion is not None:
            col_fields.append(religion)
        return [row_fields, col_fields]

    def merge_address(self, fields: List[OcrResult]):
        """Normalise the entries above the ID number to name/birthplace/native place.

        ``fields`` is expected sorted top-to-bottom. The first entry whose text
        holds 10-18 consecutive digits starts the tail (the ID number and what
        follows); the entries before it are adjusted by count:

        - 2: the second entry is duplicated (native place reused from birthplace);
        - 3: kept as is;
        - 4: one address spans two lines; a last line shorter than 7 characters
          is joined to the native place, otherwise lines 2-3 form the birthplace;
        - 5: both addresses span two lines and both are joined.

        Any other count, or no ID number at all, leaves the entries unchanged.
        A 兴安盟科尔沁右翼{前,中,后}旗 address is rebuilt from the boxes that
        mention the banner.

        :return: a new list; the input list is not modified.
        """
        id_index = 0
        head = []
        for k, f in enumerate(fields):
            if re.search(r'\d{10,18}', f.txt):
                id_index = k
                head = fields[:k]
                break
        tail = fields[id_index:]

        banners = (
            ('前', '内蒙古兴安盟科尔沁右翼前旗'),
            ('中', '内蒙古兴安盟科尔沁右翼中旗'),
            ('后', '内蒙古兴安盟科尔沁右翼后旗'),
        )
        for entry in head:
            if '兴安盟科尔' not in entry.txt:
                continue
            num = 1
            for r in self.ocr_res:
                for marker, prefix in banners:
                    # Rewrite index 1 (birthplace) then 2 (native place), never
                    # past the end of `head`.
                    if marker in r.txt and num < len(head):
                        head[num].txt = prefix + r.txt.split('旗')[-1]
                        num += 1
                        if num > 2:
                            return head + tail

        if len(head) == 2:
            head.append(head[1])
        elif len(head) == 4:
            if len(head[-1].txt) < 7:
                # Native place spans two lines.
                head[-2] = self.merger_or(head[-2], head[-1])
                del head[-1]
            else:
                # Birthplace spans two lines.
                head[1] = self.merger_or(head[1], head[2])
                del head[2]
        elif len(head) == 5:
            head[1] = self.merger_or(head[1], head[2])
            head[-2] = self.merger_or(head[-2], head[-1])
            del head[2]
            del head[-1]
        return head + tail

    def parse(self, eps=5.0):
        """Parse the resident-page fields; see ``all_required_fields``."""
        return self.all_required_fields(eps)

    def all_required_fields_f(self, eps):
        """Collect the fields of the front page.

        :param eps: unused; kept for signature compatibility.
        :return: ``[row_fields, col_fields]``: the address (住址) boxes to the
            right of the anchor and the household-type (户别) boxes above it.
            Either list holds a single empty zero-box result when nothing matched.
        """
        def merge_address(fields):
            # Two-line address: the first entry gets the joined text (the
            # second entry is kept as is).
            if len(fields) == 2:
                fields[0] = OcrResult(fields[0].box, fields[0].txt + fields[1].txt, fields[0].conf)
            return fields

        def zero_ocr():
            return OcrResult(np.zeros((4, 2)), '', 0.)

        anchor = self.anchor
        esp = self.mean_h_esp
        a_l, a_t, a_r, _ = anchor.ltrb
        a_cx = anchor.center[0]
        a_w = anchor.wh[0]

        # Row -> 住址: skip single characters and the bare label; the first
        # accepted box must be at least 6 characters long.
        row_fields = []
        for r in self.ocr_res:
            if len(r.txt) == 1 or ('住址' in r.txt and len(r.txt) < 7):
                continue
            if not row_fields and len(r.txt) < 6:
                continue
            if abs(r.ltrb[1] - a_t) < esp * 1.5 and r.center[0] - a_cx > a_w * 0.8:
                row_fields.append(r)
        if not row_fields:
            row_fields.append(zero_ocr())
        merge_address(row_fields)

        # Column -> 户别: boxes centred within the anchor's horizontal span and
        # above its top edge.
        col_fields = []
        for c in self.ocr_res:
            if not col_fields and len(c.txt) == 1:
                continue
            if a_l < c.center[0] < a_r and a_t > c.center[1] > esp * 1.5:
                col_fields.append(c)
        if not col_fields:
            col_fields.append(zero_ocr())
        return [row_fields, col_fields]

    def parse_f(self, eps=5.0):
        """Parse the front-page fields; returns False when ``anchor`` is False."""
        return False if self.anchor is False else self.all_required_fields_f(eps)
class PeopleParser(RanksParser):
    """Parser for the resident-population page (常驻人口页, page type 0).

    The anchor is the ID number (身份证号).
    """

    def __init__(self, res: List):
        super().__init__(res)

    @property
    def anchor(self):
        """Return the ID-number box as a new OcrResult holding only the number.

        For each box, the first X-terminated digit run is taken when the text
        contains 'X'/'x', otherwise the first run of 10-18 digits. The match is
        accepted when it is longer than 10 characters and the box's vertical
        centre lies inside ``center_five_row``.

        :raises Exception: when no box qualifies.
        """
        band = None
        for r in self.ocr_res:
            txt = r.txt
            if 'X' in txt or 'x' in txt:
                # `[Xx]`: the former `[X|x]` also matched a literal '|'.
                match = re.search(r'\d*[Xx]', txt)
            else:
                match = re.search(r'\d{10,18}', txt)
            if match is None or len(match.group()) <= 10:
                continue
            if band is None:
                # Scans every box, so compute it once, on the first candidate.
                band = self.center_five_row
            if band[0] < r.center[1] < band[1]:
                return OcrResult(r.box, match.group(), r.conf)
        raise Exception("没有找到身份证号")

    def parse(self, eps=5.0):
        return super().parse(eps)
class FrontParser(RanksParser):
    """Parser for the register's front page (户口本首页, page type 1)."""

    def __init__(self, res: List):
        super().__init__(res)

    @property
    def anchor(self):
        """Return the anchor box for the front-page layout.

        Candidates are boxes whose text, with '-' removed, contains 5-12
        consecutive digits. With one candidate it is returned; with several,
        the second one is returned.

        :return: the chosen OcrResult, or False when there is no candidate
            (``parse_f`` checks ``anchor is False``).
        """
        candidates = [r for r in self.ocr_res
                      if re.search(r'\d{5,12}', r.txt.replace('-', ''))]
        if not candidates:
            return False
        return candidates[1] if len(candidates) > 1 else candidates[0]

    def parse_f(self, eps=5.0):
        return super().parse_f(eps)
|