|
@@ -9,16 +9,16 @@ import time
|
|
|
import base64
|
|
|
from itertools import chain
|
|
|
from tqdm import tqdm
|
|
|
-import numpy as np
|
|
|
-from ocr_config import OCR_CONFIGS
|
|
|
+from ocr_config import OCR_CONFIGS, Filed
|
|
|
|
|
|
|
|
|
class Image:
|
|
|
- def __init__(self, path: Path, rotate):
|
|
|
+ def __init__(self, path: Path, rotate, is_rotate):
|
|
|
self._path = path
|
|
|
self.rotate = rotate
|
|
|
self._ocr_result = None
|
|
|
self.category = True
|
|
|
+ self.is_rotate = is_rotate
|
|
|
try:
|
|
|
self.gt_result = self.get_json()
|
|
|
except Exception as e:
|
|
@@ -28,6 +28,7 @@ class Image:
|
|
|
def __repr__(self):
|
|
|
return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.category}'
|
|
|
|
|
|
+ # 将方法转换为相同名称的只读属性
|
|
|
@property
|
|
|
def path(self):
|
|
|
return self._path
|
|
@@ -48,9 +49,12 @@ class Image:
|
|
|
def ocr_result(self, value):
|
|
|
self._ocr_result = value
|
|
|
|
|
|
- def get_gt_result(self, key):
|
|
|
+ def get_gt_result(self, key):# sourcery skip: merge-duplicate-blocks, remove-redundant-if
|
|
|
if key == 'orientation':
|
|
|
- return self.rotate + 1 if self.rotate is not None else 0
|
|
|
+ if self.is_rotate:
|
|
|
+ return self.rotate + 1 if self.rotate is not None else 0
|
|
|
+ else:
|
|
|
+ return self.gt_result[key]
|
|
|
elif key in self.gt_result:
|
|
|
return self.gt_result[key]
|
|
|
else:
|
|
@@ -67,16 +71,19 @@ class Image:
|
|
|
# print('save image', self.path)
|
|
|
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
|
|
cv2.imwrite(str(self.path), img)
|
|
|
+ return self.path
|
|
|
|
|
|
def get_base64(self, rotate=None):
|
|
|
# print(self.path)
|
|
|
img = cv2.imread(str(self.path))
|
|
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
+ path = self.path
|
|
|
if rotate is not None:
|
|
|
img = cv2.rotate(img, rotate)
|
|
|
- self.save_image(img, rotate)
|
|
|
- _, img = cv2.imencode('.jpg', img)
|
|
|
- return base64.b64encode(img).decode('utf-8')
|
|
|
+ path = self.save_image(img, rotate)
|
|
|
+ # imencode 将图片编码到缓存,并保存到本地
|
|
|
+ with open(path, 'rb') as f:
|
|
|
+ return base64.encodebytes(f.read()).decode('utf-8')
|
|
|
|
|
|
def get_json(self):
|
|
|
with open(self.json_path, 'r') as f:
|
|
@@ -107,26 +114,22 @@ def parser_path(path: Path, rotate: bool):
|
|
|
|
|
|
|
|
|
class Dataset(object):
|
|
|
- def __init__(self, images_path, image_type, ocr_name, ocr_address, rotate=False):
|
|
|
+ def __init__(self, images_path, image_type, ocr_name, ocr_address, field, rotate=False):
|
|
|
self.image_type = image_type
|
|
|
self.ocr_name = ocr_name
|
|
|
self.ocr_address = ocr_address
|
|
|
self.images_path = images_path
|
|
|
self.image_list = []
|
|
|
+ # chain 迭代器,首先返回第一个可迭代对象中所有元素,接着返回下一个可迭代对象中所有元素,直到耗尽所有可迭代对象中的元素
|
|
|
+ # eg:chain('ABC', 'DEF') --> A B C D E F
|
|
|
+
|
|
|
for p in chain(*[Path(self.images_path).rglob('*.jpg')]):
|
|
|
if rotate:
|
|
|
- self.image_list.extend(Image(p, r) for r in [None, 0, 1, 2])
|
|
|
+ self.image_list.extend(Image(p, r, rotate) for r in [None, 0, 1, 2])
|
|
|
else:
|
|
|
- self.image_list.append(Image(p, None))
|
|
|
+ self.image_list.append(Image(p, None, rotate))
|
|
|
|
|
|
- self.field = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']
|
|
|
- # if self.image_type:
|
|
|
- # self.field = ['orientation', 'type', 'address', 'address_province', 'address_city', 'address_region',
|
|
|
- # 'address_detail']
|
|
|
- # else:
|
|
|
- # self.field = ['orientation', 'name', 'id', 'gender', 'birthplace', 'birthplace_province', 'birthplace_city',
|
|
|
- # 'birthplace_region', 'native_place', 'native_place_province', 'native_place_city',
|
|
|
- # 'native_place_region', 'blood_type', 'religion']
|
|
|
+ self.field = Filed.get(field)
|
|
|
|
|
|
self.correct = {k: 0 for k in self.field}
|
|
|
self.error = {k: 0 for k in self.field}
|
|
@@ -136,6 +139,7 @@ class Dataset(object):
|
|
|
|
|
|
def _evaluate_one(self, image: Image):
|
|
|
def _get_predict(r, key):
|
|
|
+ # isinstance() 函数来判断一个对象是否是一个已知的类型
|
|
|
if isinstance(r[key], dict):
|
|
|
return r[key]['text']
|
|
|
else:
|