123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- import math
- import numpy as np
- from dataclasses import dataclass
# Result object: one OCR detection.
@dataclass
class OcrResult(object):
    """A single OCR detection: a polygon, its text and its confidence.

    Attributes:
        box: Array of polygon vertices, one ``(x, y)`` row per vertex.
        txt: Recognised text.
        conf: Recognition confidence.

    Equality and hashing are both derived from ``repr(self)`` so that
    instances can be stored in sets / used as dict keys.  The
    dataclass-generated ``__eq__`` would compare the ``box`` arrays
    element-wise and raise ``ValueError`` (ambiguous truth value) whenever
    two distinct instances with the same hash are compared.
    """

    box: np.ndarray
    txt: str
    conf: float

    def __eq__(self, other):
        # Must stay consistent with __hash__, which is repr-based.
        if not isinstance(other, OcrResult):
            return NotImplemented
        return repr(self) == repr(other)

    def __hash__(self):
        return hash(repr(self))

    def __repr__(self):
        return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'

    @property
    def lt(self):
        """Left-top corner ``[min_x, min_y]`` of the axis-aligned bounds."""
        left, top = np.min(self.box, 0)
        return [left, top]

    @property
    def rb(self):
        """Right-bottom corner ``[max_x, max_y]`` of the axis-aligned bounds."""
        right, bottom = np.max(self.box, 0)
        return [right, bottom]

    @property
    def wh(self):
        """``[width, height]`` of the axis-aligned bounds."""
        left, top = self.lt
        right, bottom = self.rb
        return [right - left, bottom - top]

    @property
    def area(self):
        """Area of the axis-aligned bounds (width * height)."""
        width, height = self.wh
        return width * height

    @property
    def is_slope(self):
        """Slope (dy / dx) of the edge from the first to the second vertex.

        Despite the name this is not a boolean: it returns the tangent of
        the edge's tilt, or ``0`` when the two vertices share the same x
        (which would otherwise divide by zero).  The original note says
        boxes tilted by roughly 10-60 degrees (either direction) need the
        image rotated because the detection model cannot handle them; the
        thresholding itself is not done here.
        """
        p0 = self.box[0]
        p1 = self.box[1]
        if p0[0] == p1[0]:  # no horizontal extent: treat as "no rotation"
            return 0
        return 1. * (p1[1] - p0[1]) / (p1[0] - p0[0])

    @property
    def center(self):
        """Centre ``[cx, cy]`` of the axis-aligned bounds."""
        left, top = self.lt
        right, bottom = self.rb
        return [(right + left) / 2, (bottom + top) / 2]

    def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
        """Return whether ``self`` and ``b`` belong to the same text line.

        Args:
            b: The other detection.
            is_horizontal: Truthy when lines run horizontally; the line
                axis is then x (index 0) and the cross axis is y (index 1).
                Falsy swaps the two.
            eps: Ignored.  Kept only for backward compatibility; the
                tolerance is always derived from the two boxes' extents.

        Returns:
            ``False`` if one box's span along the line axis lies strictly
            inside the other's; otherwise whether the centres differ along
            the cross axis by less than a quarter of the summed cross-axis
            extents (i.e. half of their average).
        """
        cross = 0 + is_horizontal
        along = 1 - cross

        self_lt, self_rb = self.lt, self.rb
        b_lt, b_rb = b.lt, b.rb

        # One box strictly nested inside the other along the line axis:
        # never treated as the same line.
        if b_lt[along] < self_lt[along] < self_rb[along] < b_rb[along]:
            return False
        if self_lt[along] < b_lt[along] < b_rb[along] < self_rb[along]:
            return False

        tolerance = 0.25 * (self.wh[cross] + b.wh[cross])
        distance = abs(self.center[cross] - b.center[cross])
        return distance < tolerance
# Line parser: groups OCR detections into text lines.
class LineParser(object):
    """Group raw OCR detections into text lines.

    Attributes:
        ocr_res: Detections wrapped as ``OcrResult`` objects, sorted by
            bounding-box area, largest first.
        eps: ``0.7 * avg_height``; forwarded to ``OcrResult.one_line``.
    """

    def __init__(self, ocr_raw_result, filters=None):
        """Wrap and sort the raw detections.

        Args:
            ocr_raw_result: Iterable of items where ``item[0]`` is the box
                (array-like of points), ``item[1][0]`` the text and
                ``item[1][1]`` the confidence.  ``None`` is treated as
                "no detections".
            filters: Unused; accepted for backward compatibility only
                (no filtering is applied).
        """
        if ocr_raw_result is None:
            ocr_raw_result = ()
        detections = [
            OcrResult(np.array(item[0]), item[1][0], item[1][1])
            for item in ocr_raw_result
        ]
        # Largest box first.
        self.ocr_res = sorted(detections, key=lambda r: r.area, reverse=True)
        self.eps = self.avg_height * 0.7

    @property
    def is_horizontal(self):
        """Whether more boxes are wider than tall than the other way round.

        Returns ``False`` when there are no detections or on a tie.
        """
        if not self.ocr_res:
            return False
        sizes = np.array([r.wh for r in self.ocr_res])
        wider = np.count_nonzero(sizes[:, 0] > sizes[:, 1])
        taller = np.count_nonzero(sizes[:, 0] < sizes[:, 1])
        return bool(wider > taller)

    @property
    def avg_height(self):
        """Mean cross-axis extent of the boxes (height for horizontal text).

        Returns ``0.0`` when there are no detections.
        """
        if not self.ocr_res:
            return 0.0
        idx = self.is_horizontal + 0
        return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))

    # Overall confidence.
    @property
    def confidence(self):
        """Mean confidence over all detections (``0.0`` when there are none)."""
        if not self.ocr_res:
            return 0.0
        return np.mean([r.conf for r in self.ocr_res])

    # Main grouping routine.
    def parse(self, eps=40.0):
        """Group the detections into lines.

        Each not-yet-assigned detection (in ``ocr_res`` order) starts a new
        line and collects every later unassigned detection for which
        ``one_line`` holds against it.

        Args:
            eps: Unused; kept for backward compatibility.

        Returns:
            A list of lines.  Each line is a list of detections ordered
            along the line axis; the lines are ordered by the cross-axis
            position of their first detection.
        """
        detections = self.ocr_res
        if not detections:
            return []

        # Hoisted: the property rescans every box on each access.
        horizontal = self.is_horizontal
        idx = int(horizontal)
        count = len(detections)

        # Track assignment by position rather than by hash/eq lookups, so
        # grouping does not depend on how detections compare to each other.
        assigned = [False] * count
        rows = []
        for i, anchor in enumerate(detections):
            if assigned[i]:
                continue
            assigned[i] = True
            # Seed the row with its anchor so a row is never empty, even
            # for a degenerate box that fails one_line against itself.
            row = [anchor]
            for j in range(i + 1, count):
                if assigned[j]:
                    continue
                candidate = detections[j]
                if anchor.one_line(candidate, horizontal, self.eps):
                    assigned[j] = True
                    row.append(candidate)
            row.sort(key=lambda r: r.lt[1 - idx])
            rows.append(row)

        rows.sort(key=lambda line: line[0].lt[idx])
        return rows

    def detection_parse(self, eps=40.0):
        """Check whether exactly two detections share one horizontal line.

        Args:
            eps: Unused; kept for backward compatibility.

        Returns:
            The result of ``one_line`` (horizontal mode) for the two
            detections, or ``None`` when there are not exactly two.
        """
        if len(self.ocr_res) != 2:
            return None
        first, second = self.ocr_res
        return first.one_line(second, True, self.eps)
|