|
@@ -1,6 +1,7 @@
|
|
import numpy as np
|
|
import numpy as np
|
|
from dataclasses import dataclass
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
+
|
|
# result 对象
|
|
# result 对象
|
|
@dataclass
|
|
@dataclass
|
|
class OcrResult(object):
|
|
class OcrResult(object):
|
|
@@ -30,6 +31,20 @@ class OcrResult(object):
|
|
r, b = self.rb
|
|
r, b = self.rb
|
|
return [r - l, b - t]
|
|
return [r - l, b - t]
|
|
|
|
|
|
|
|
@property
def area(self):
    """Return the product of the two entries of ``self.wh``."""
    width, height = self.wh
    return width * height
|
|
|
|
+
|
|
|
|
@property
def is_slope(self):
    """Tell whether the edge from ``box[0]`` to ``box[1]`` is noticeably tilted.

    Returns True when the absolute slope ``|dy / dx|`` of that edge lies
    strictly between 0.4 and 2.5. An edge whose two points share the same
    first coordinate (``dx == 0``) is reported as not sloped.
    """
    first, second = self.box[0], self.box[1]
    # Guard the division below: equal first coordinates mean dx == 0.
    if first[0] == second[0]:
        return False
    rise = first[1] - second[1]
    run = first[0] - second[0]
    steepness = abs(1. * rise / run)
    return 0.4 < steepness < 2.5
|
|
|
|
+
|
|
@property
|
|
@property
|
|
def center(self):
|
|
def center(self):
|
|
l, t = self.lt
|
|
l, t = self.lt
|
|
@@ -37,25 +52,32 @@ class OcrResult(object):
|
|
return [(r + l) / 2, (b + t) / 2]
|
|
return [(r + l) / 2, (b + t) / 2]
|
|
|
|
|
|
def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
    """Decide whether this box and ``b`` sit on the same text line.

    Args:
        b: The other box; its ``lt``, ``rb``, ``wh`` and ``center`` are read.
        is_horizontal: Truthy selects index 1 as the cross-line axis and
            index 0 as the along-line axis; falsy swaps the two.
        eps: Kept for interface compatibility only. The passed value is
            never read, because the tolerance is recomputed from the two
            boxes' sizes.

    Returns:
        False when one box's along-line extent lies strictly inside the
        other's. Otherwise, whether the two centers differ along the
        cross-line axis by less than 0.45 times the sum of both boxes'
        sizes on that axis.
    """
    cross = 0 + is_horizontal
    along = 1 - cross

    # Strict containment along the line direction rules the pair out.
    self_inside_b = b.lt[along] < self.lt[along] < self.rb[along] < b.rb[along]
    b_inside_self = self.lt[along] < b.lt[along] < b.rb[along] < self.rb[along]
    if self_inside_b or b_inside_self:
        return False

    tolerance = 0.45 * (self.wh[cross] + b.wh[cross])
    gap = abs(self.center[cross] - b.center[cross])
    return gap < tolerance
|
|
|
|
|
|
|
|
|
|
# 行处理器
|
|
# 行处理器
|
|
class LineParser(object):
|
|
class LineParser(object):
|
|
def __init__(self, ocr_raw_result, filters=None):
    """Wrap raw OCR entries, drop filtered ones and precompute ``eps``.

    Args:
        ocr_raw_result: Iterable of raw entries. For each entry,
            ``entry[0]`` is converted with ``np.array`` and passed to
            ``OcrResult`` together with ``entry[1][0]`` and ``entry[1][1]``
            (presumably box, text and confidence; confirm against
            ``OcrResult``).
        filters: Optional list of predicates taking an ``OcrResult``. A
            result is discarded as soon as one predicate returns a truthy
            value. Defaults to discarding results whose ``is_slope`` is
            truthy.
    """
    if filters is None:
        filters = [lambda x: x.is_slope]

    kept = []
    for entry in ocr_raw_result:
        item = OcrResult(np.array(entry[0]), entry[1][0], entry[1][1])
        # Generator form stops at the first matching filter.
        if any(f(item) for f in filters):
            continue
        kept.append(item)

    # Largest boxes first.
    self.ocr_res = sorted(kept, key=lambda x: x.area, reverse=True)

    # With nothing left after filtering, the orientation statistic behind
    # avg_height cannot be computed (np.stack rejects an empty list), so
    # fall back to a zero tolerance instead of raising in the constructor.
    self.eps = self.avg_height * 0.86 if self.ocr_res else 0.0
|
|
|
|
|
|
- # 判断是否水平
|
|
|
|
@property
def is_horizontal(self):
    """Report whether most results have a larger index-0 than index-1 extent.

    For every entry of ``self.ocr_res`` the absolute difference between
    ``lt`` and ``rb`` is taken per axis. The result is truthy when strictly
    more entries have a larger extent on axis 0 than on axis 1; ties count
    for neither side. The value is a NumPy boolean, so callers may add 0 to
    it to obtain an index.
    """
    extents = np.stack(
        [np.abs(np.array(item.lt) - np.array(item.rb)) for item in self.ocr_res]
    )
    along_0 = extents[:, 0]
    along_1 = extents[:, 1]
    wider_count = np.sum(along_0 > along_1)
    taller_count = np.sum(along_0 < along_1)
    return wider_count > taller_count
|
|
@@ -65,6 +87,7 @@ class LineParser(object):
|
|
idx = self.is_horizontal + 0
|
|
idx = self.is_horizontal + 0
|
|
return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
|
|
return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
|
|
|
|
|
|
|
|
# Overall confidence
@property
def confidence(self):
    """Return the arithmetic mean of ``conf`` over all of ``self.ocr_res``."""
    scores = [item.conf for item in self.ocr_res]
    return np.mean(scores)
|
|
@@ -86,13 +109,6 @@ class LineParser(object):
|
|
# 拿出 OcrResult对象的 第i值 -暂存-
|
|
# 拿出 OcrResult对象的 第i值 -暂存-
|
|
res_i = self.ocr_res[i]
|
|
res_i = self.ocr_res[i]
|
|
|
|
|
|
- # any:-> True
|
|
|
|
- # -input: 可迭代对象 | -output: bool
|
|
|
|
- # -如果iterable的任何元素为true,则返回true。如果iterable为空,则返回false。 -与🚪-
|
|
|
|
- # map: -> [False, False, False, False, True, True, False, False]
|
|
|
|
- # -input: (函数, 可迭代对象) | -output: 可迭代对象
|
|
|
|
- # -把 res 喂给lambda --lambda返回True的值--> 输出 新的可迭代对象
|
|
|
|
-
|
|
|
|
# 这次的 res_i 之前已经在结果集中,就继续下一个
|
|
# 这次的 res_i 之前已经在结果集中,就继续下一个
|
|
if any(map(lambda x: res_i in x, res)): continue
|
|
if any(map(lambda x: res_i in x, res)): continue
|
|
|
|
|
|
@@ -102,10 +118,18 @@ class LineParser(object):
|
|
|
|
|
|
for j in range(i, length):
|
|
for j in range(i, length):
|
|
res_j = self.ocr_res[j]
|
|
res_j = self.ocr_res[j]
|
|
|
|
+ # Skip res_j when an earlier row of the result set already contains it
|
|
|
|
+ if any(map(lambda x: res_j in x, res)): continue
|
|
|
|
+
|
|
if res_i.one_line(res_j, self.is_horizontal, self.eps):
|
|
if res_i.one_line(res_j, self.is_horizontal, self.eps):
|
|
# LineParser 对象 不可以直接加入字典
|
|
# LineParser 对象 不可以直接加入字典
|
|
|
|
|
|
res_row.add(res_j)
|
|
res_row.add(res_j)
|
|
res.append(res_row)
|
|
res.append(res_row)
|
|
idx = self.is_horizontal + 0
|
|
idx = self.is_horizontal + 0
|
|
- return sorted([list(r) for r in res], key=lambda x: x[0].lt[idx])
|
|
|
|
|
|
+ res = sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in res], key=lambda x: x[0].lt[idx])
|
|
|
|
+ for row in res:
|
|
|
|
+ print('---')
|
|
|
|
+ print(''.join([r.txt for r in row]))
|
|
|
|
+ return res
|
|
|
|
+
|