import math
from dataclasses import dataclass

import numpy as np


# Result object
@dataclass
class OcrResult(object):
    """One OCR detection: a box of corner points, its text and a confidence.

    ``box`` is read as an array of (x, y) points (two columns): ``box[0]`` and
    ``box[1]`` are the first two corners and column-wise min/max give the
    axis-aligned bounding rectangle.
    """

    box: np.ndarray
    txt: str
    conf: float

    def __eq__(self, other):
        # The dataclass-generated __eq__ would compare the ndarray field with
        # ``==`` and raise "truth value of an array is ambiguous" for two
        # distinct instances, so compare the box explicitly by value.
        if not isinstance(other, OcrResult):
            return NotImplemented
        return bool(
            self.txt == other.txt
            and self.conf == other.conf
            and np.array_equal(self.box, other.box)
        )

    def __hash__(self):
        # Hash the same values __eq__ compares so equal objects hash equally.
        return hash((self.txt, self.conf, tuple(np.ravel(self.box).tolist())))

    def __repr__(self):
        return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'

    @property
    def lt(self):
        """Return ``[left, top]``: the column-wise minimum of the box points."""
        l, t = np.min(self.box, 0)
        return [l, t]

    @property
    def rb(self):
        """Return ``[right, bottom]``: the column-wise maximum of the box points."""
        r, b = np.max(self.box, 0)
        return [r, b]

    @property
    def wh(self):
        """Return ``[width, height]`` of the axis-aligned bounding rectangle."""
        l, t = self.lt
        r, b = self.rb
        return [r - l, b - t]

    @property
    def area(self):
        """Return the area of the axis-aligned bounding rectangle."""
        w, h = self.wh
        return w * h

    @property
    def is_slope(self):
        """Return the slope (tangent) of the edge from ``box[0]`` to ``box[1]``.

        Original author's note: text tilted between 10~60 or -60~-10 degrees
        requires rotating the image, because the detection model cannot
        handle tilted text; the returned value is the tangent of that angle.

        Returns 0 when both points share the same x coordinate (this also
        avoids a division by zero).
        """
        p0 = self.box[0]
        p1 = self.box[1]
        if p0[0] == p1[0]:
            return 0
        slope = 1. * (p1[1] - p0[1]) / (p1[0] - p0[0])
        return slope

    @property
    def center(self):
        """Return ``[cx, cy]``, the centre of the axis-aligned bounding rectangle."""
        l, t = self.lt
        r, b = self.rb
        return [(r + l) / 2, (b + t) / 2]

    def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
        """Tell whether ``self`` and ``b`` lie on the same text line.

        Args:
            b: The other ``OcrResult``.
            is_horizontal: Truthy when lines run horizontally; then centres
                are compared along y and containment is checked along x
                (the axes are swapped otherwise).
            eps: Ignored. Kept for backward compatibility: the tolerance is
                always recomputed as a quarter of the summed box extents
                across the line direction.

        Returns:
            False when one box's extent along the line strictly contains the
            other's; otherwise whether the centre distance across the line
            is below the computed tolerance.
        """
        y_idx = 0 + is_horizontal
        x_idx = 1 - y_idx
        # One box strictly nested inside the other along the line direction
        # is never treated as the same line.
        if b.lt[x_idx] < self.lt[x_idx] < self.rb[x_idx] < b.rb[x_idx]:
            return False
        if self.lt[x_idx] < b.lt[x_idx] < b.rb[x_idx] < self.rb[x_idx]:
            return False
        eps = 0.25 * (self.wh[y_idx] + b.wh[y_idx])
        dist = abs(self.center[y_idx] - b.center[y_idx])
        return dist < eps


# Line parser
class LineParser(object):
    """Group raw OCR detections into text lines."""

    def __init__(self, ocr_raw_result, filters=None):
        """Build ``OcrResult`` objects from raw OCR output.

        Args:
            ocr_raw_result: Iterable of items indexed as
                ``item[0]`` -> box points, ``item[1][0]`` -> text,
                ``item[1][1]`` -> confidence. ``None`` is treated as empty.
            filters: Accepted for backward compatibility; filtering is
                currently disabled, so the value is not used.
        """
        if ocr_raw_result is None:
            ocr_raw_result = []
        detections = [
            OcrResult(np.array(item[0]), item[1][0], item[1][1])
            for item in ocr_raw_result
        ]
        # Largest boxes first. Original author's note: the largest detection
        # is most likely where the card number is.
        self.ocr_res = sorted(detections, key=lambda x: x.area, reverse=True)
        self.eps = self.avg_height * 0.7

    @property
    def is_horizontal(self):
        """Return True when more boxes are wider than tall than the reverse.

        Returns False for an empty result list.
        """
        if not self.ocr_res:
            return False
        sizes = np.array([r.wh for r in self.ocr_res])
        wider = np.sum(sizes[:, 0] > sizes[:, 1])
        taller = np.sum(sizes[:, 0] < sizes[:, 1])
        return bool(wider > taller)

    @property
    def avg_height(self):
        """Return the mean box extent across the line direction.

        Uses box height for horizontal layouts and box width otherwise;
        returns 0.0 when there are no results.
        """
        if not self.ocr_res:
            return 0.0
        idx = self.is_horizontal + 0
        return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))

    # Overall confidence
    @property
    def confidence(self):
        """Return the mean confidence of all results (0.0 when empty)."""
        if not self.ocr_res:
            return 0.0
        return np.mean([r.conf for r in self.ocr_res])

    # Parser entry point
    def parse(self, eps=40.0):
        """Group the results into lines.

        Each not-yet-assigned result (in ``self.ocr_res`` order) starts a new
        line and collects every later unassigned result that ``one_line``
        accepts.

        Args:
            eps: Unused; kept for backward compatibility.

        Returns:
            A list of lines, each a list of ``OcrResult``. Items within a
            line are ordered along the line direction, and lines are ordered
            by the position of their first item across it.
        """
        results = self.ocr_res
        count = len(results)
        # Hoisted: the property does a full pass over the results.
        horizontal = self.is_horizontal
        # Track assignment by index rather than by set membership, so no
        # hashing or equality on OcrResult is needed and duplicates are kept.
        assigned = [False] * count
        rows = []
        for i, anchor in enumerate(results):
            if assigned[i]:
                continue
            assigned[i] = True
            # Always seed the line with its anchor: a zero-extent box does
            # not match itself in one_line and would otherwise be dropped.
            row = [anchor]
            for j in range(i + 1, count):
                if assigned[j]:
                    continue
                candidate = results[j]
                if anchor.one_line(candidate, horizontal, self.eps):
                    assigned[j] = True
                    row.append(candidate)
            rows.append(row)

        idx = horizontal + 0
        rows = [sorted(row, key=lambda x: x.lt[1 - idx]) for row in rows]
        rows.sort(key=lambda row: row[0].lt[idx])
        return rows

    def detection_parse(self, eps=40.0):
        """Check whether exactly two results share one horizontal line.

        Args:
            eps: Unused; kept for backward compatibility.

        Returns:
            The ``one_line`` verdict for the two results (treated as a
            horizontal layout) when there are exactly two results,
            otherwise None.
        """
        result = self.ocr_res
        if len(result) == 2:
            return result[0].one_line(result[1], True, self.eps)
        return None