|
@@ -0,0 +1,270 @@
|
|
|
+from pathlib import Path
|
|
|
+from typing import List, Optional
|
|
|
+import cv2
|
|
|
+import requests
|
|
|
+from mdutils.mdutils import MdUtils
|
|
|
+from dataclasses import dataclass
|
|
|
+import json
|
|
|
+import time
|
|
|
+import base64
|
|
|
+from itertools import chain
|
|
|
+from tqdm import tqdm
|
|
|
+from ocr_config import OCR_CONFIGS, Filed
|
|
|
+
|
|
|
+
|
|
|
+class Image:
|
|
|
+ def __init__(self, path: Path, rotate, is_rotate):
|
|
|
+ self._path = path
|
|
|
+ self.rotate = rotate
|
|
|
+ self._ocr_result = None
|
|
|
+ self.category = True
|
|
|
+ self.is_rotate = is_rotate
|
|
|
+ try:
|
|
|
+ self.gt_result = self.get_json()
|
|
|
+ except Exception as e:
|
|
|
+ print(self.json_path)
|
|
|
+ raise e
|
|
|
+
|
|
|
+ def __repr__(self):
|
|
|
+ return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.category}'
|
|
|
+
|
|
|
+ # 将方法转换为相同名称的只读属性
|
|
|
+ @property
|
|
|
+ def path(self):
|
|
|
+ return self._path
|
|
|
+
|
|
|
+ @path.setter
|
|
|
+ def path(self, path):
|
|
|
+ self._path = path
|
|
|
+
|
|
|
+ @property
|
|
|
+ def fn(self):
|
|
|
+ return self._path.stem
|
|
|
+
|
|
|
+ @property
|
|
|
+ def ocr_result(self):
|
|
|
+ return self._ocr_result
|
|
|
+
|
|
|
+ @ocr_result.setter
|
|
|
+ def ocr_result(self, value):
|
|
|
+ self._ocr_result = value
|
|
|
+
|
|
|
+ def get_gt_result(self, key):# sourcery skip: merge-duplicate-blocks, remove-redundant-if
|
|
|
+ if key == 'orientation':
|
|
|
+ if self.is_rotate:
|
|
|
+ return self.rotate + 1 if self.rotate is not None else 0
|
|
|
+ else:
|
|
|
+ return self.gt_result[key]
|
|
|
+ elif key in self.gt_result:
|
|
|
+ return self.gt_result[key]
|
|
|
+ else:
|
|
|
+ return None
|
|
|
+
|
|
|
+ @property
|
|
|
+ def json_path(self):
|
|
|
+ return self.path.parent / f'{self.path.stem}.json'
|
|
|
+
|
|
|
+ def save_image(self, img, rotate):
|
|
|
+ dst = self.path.parent.parent / (".ro_dst")
|
|
|
+ if not dst.exists(): dst.mkdir()
|
|
|
+ self.path = dst / f'{self.path.stem}-{rotate + 1}.jpg'
|
|
|
+ # print('save image', self.path)
|
|
|
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
|
|
+ cv2.imwrite(str(self.path), img)
|
|
|
+ return self.path
|
|
|
+
|
|
|
+ def get_base64(self, rotate=None):
|
|
|
+ # print(self.path)
|
|
|
+ img = cv2.imread(str(self.path))
|
|
|
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
+ path = self.path
|
|
|
+ if rotate is not None:
|
|
|
+ img = cv2.rotate(img, rotate)
|
|
|
+ path = self.save_image(img, rotate)
|
|
|
+ # imencode 将图片编码到缓存,并保存到本地
|
|
|
+ with open(path, 'rb') as f:
|
|
|
+ return base64.encodebytes(f.read()).decode('utf-8')
|
|
|
+
|
|
|
+ def get_json(self):
|
|
|
+ with open(self.json_path, 'r') as f:
|
|
|
+ return json.load(f)
|
|
|
+
|
|
|
+
|
|
|
+def send_request(image: Image, ocr_name, ocr_address, image_type=None):
|
|
|
+ base64_str = image.get_base64(image.rotate)
|
|
|
+ config = OCR_CONFIGS[ocr_name][ocr_address]
|
|
|
+ headers = {
|
|
|
+ 'Content-Type': 'application/json',
|
|
|
+ 'Authorization': config.token
|
|
|
+ }
|
|
|
+ data = {
|
|
|
+ 'image': base64_str,
|
|
|
+ }
|
|
|
+ if image_type is not None:
|
|
|
+ data['image_type'] = image_type
|
|
|
+ response = requests.post(config.url, headers=headers, json=data)
|
|
|
+ return response.json()
|
|
|
+
|
|
|
+
|
|
|
+def parser_path(path: Path, rotate: bool):
|
|
|
+ name = time.strftime("%m-%d_", time.localtime()) + path.name
|
|
|
+ if rotate:
|
|
|
+ name = f'{name}_R.md'
|
|
|
+ return path.parent / name
|
|
|
+
|
|
|
+
|
|
|
+class Dataset(object):
|
|
|
+ def __init__(self, images_path, image_type, ocr_name, ocr_address, field, rotate=False):
|
|
|
+ self.image_type = image_type
|
|
|
+ self.ocr_name = ocr_name
|
|
|
+ self.ocr_address = ocr_address
|
|
|
+ self.images_path = images_path
|
|
|
+ self.image_list = []
|
|
|
+ # chain 迭代器,首先返回第一个可迭代对象中所有元素,接着返回下一个可迭代对象中所有元素,直到耗尽所有可迭代对象中的元素
|
|
|
+ # eg:chain('ABC', 'DEF') --> A B C D E F
|
|
|
+
|
|
|
+ for p in chain(*[Path(self.images_path).rglob('*.jpg')]):
|
|
|
+ if rotate:
|
|
|
+ self.image_list.extend(Image(p, r, rotate) for r in [None, 0, 1, 2])
|
|
|
+ else:
|
|
|
+ self.image_list.append(Image(p, None, rotate))
|
|
|
+
|
|
|
+ self.field = Filed.get(field)
|
|
|
+
|
|
|
+ self.correct = {k: 0 for k in self.field}
|
|
|
+ self.error = {k: 0 for k in self.field}
|
|
|
+
|
|
|
+ def __len__(self):
|
|
|
+ return len(self.image_list)
|
|
|
+
|
|
|
+ def _evaluate_one(self, image: Image):
|
|
|
+ def _get_predict(r, key):
|
|
|
+ # isinstance() 函数来判断一个对象是否是一个已知的类型
|
|
|
+ if isinstance(r[key], dict):
|
|
|
+ return r[key]['text']
|
|
|
+ else:
|
|
|
+ return r[key]
|
|
|
+
|
|
|
+ if image.rotate is not None: image.gt_result['orientation'] = image.rotate + 1
|
|
|
+ r = send_request(image, self.ocr_name, self.ocr_address, self.image_type)
|
|
|
+ err_str = ''
|
|
|
+ if r['status'] == '000':
|
|
|
+ res = r['result']
|
|
|
+ for key in self.field:
|
|
|
+ # print('attr: ', key)
|
|
|
+ if key in res:
|
|
|
+ gt = image.get_gt_result(key)
|
|
|
+ predict = _get_predict(res, key)
|
|
|
+ # print(f'gt: {gt}, predict: {predict}')
|
|
|
+ if predict == gt:
|
|
|
+ self.correct[key] += 1
|
|
|
+ else:
|
|
|
+ image.category = False
|
|
|
+ self.error[key] += 1
|
|
|
+ err_str += f'-------{key}-------<br>正确:{gt}<br>返回:{predict}<br>'
|
|
|
+ if image.category:
|
|
|
+ image.ocr_result = image.gt_result
|
|
|
+ else:
|
|
|
+ image.ocr_result = err_str
|
|
|
+ else:
|
|
|
+ image.ocr_result = r['msg']
|
|
|
+ image.category = False
|
|
|
+ for key in self.field:
|
|
|
+ self.error[key] += 1
|
|
|
+
|
|
|
+ def __call__(self): # sourcery skip: yield-from
|
|
|
+ # yield 返回一个生成器
|
|
|
+ for image in self.image_list:
|
|
|
+ yield image
|
|
|
+
|
|
|
+ # 比较
|
|
|
+ def evaluate(self):
|
|
|
+ for image in tqdm(self.image_list):
|
|
|
+ self._evaluate_one(image)
|
|
|
+
|
|
|
+ # 计算总体准确度
|
|
|
+ @property
|
|
|
+ def accuracy(self):
|
|
|
+ return sum(list(self.correct.values())) / sum(list(self.correct.values()) + list(self.error.values()))
|
|
|
+
|
|
|
+ # 计算元素准确度
|
|
|
+ @property
|
|
|
+ def attrs_accuracy(self):
|
|
|
+ return {k: self.correct[k] / (self.correct[k] + self.error[k]) for k in self.field}
|
|
|
+
|
|
|
+
|
|
|
+class MD(object):
|
|
|
+ def __init__(self, file_path: Path):
|
|
|
+ self.name = file_path.name
|
|
|
+ self.f = MdUtils(file_name=str(file_path))
|
|
|
+ self.field_table: List = ['字段', '正确率']
|
|
|
+ self.true_table: List = ['图片', '识别结果']
|
|
|
+ self.false_table: List = ['图片', '识别结果']
|
|
|
+ self.write_header(f'{self.name}测试报告')
|
|
|
+
|
|
|
+ def write_header(self, title, level=1):
|
|
|
+ self.f.new_header(level=level, title=title)
|
|
|
+
|
|
|
+ def write_total_accuracy(self, ds: Dataset):
|
|
|
+ def get_format_total_accuracy(ds: Dataset):
|
|
|
+ acc = ds.accuracy * 100
|
|
|
+ return "{:.2f}%".format(acc)
|
|
|
+
|
|
|
+ # 1. 拿到format之后的百分数
|
|
|
+ res = get_format_total_accuracy(ds)
|
|
|
+
|
|
|
+ # 2. 写入
|
|
|
+ self.f.new_paragraph(res)
|
|
|
+
|
|
|
+ def write_table_accuracy(self, ds: Dataset, columns=2, text_align='center'):
|
|
|
+ def format_table_accuracy(ds: Dataset):
|
|
|
+ table = ds.attrs_accuracy
|
|
|
+ for k, v in table.items():
|
|
|
+ acc = v * 100
|
|
|
+ table[k] = "{:.2f}%".format(acc)
|
|
|
+ return table
|
|
|
+
|
|
|
+ def dict_2_list(dic: dict):
|
|
|
+ l = []
|
|
|
+ for k, v in dic.items():
|
|
|
+ l.extend((k, v))
|
|
|
+ return l
|
|
|
+
|
|
|
+ table_dict = format_table_accuracy(ds)
|
|
|
+ table_list = dict_2_list(table_dict)
|
|
|
+ self.field_table.extend(table_list)
|
|
|
+
|
|
|
+ rows = len(self.field_table) // columns
|
|
|
+ self.f.new_table(columns=columns, rows=rows, text=self.field_table, text_align=text_align)
|
|
|
+
|
|
|
+ def write_table_result(self, ds: Dataset, columns=2, text_align='center'):
|
|
|
+ for image in ds.image_list:
|
|
|
+ md_image = self.f.new_inline_image(text='', path=f'{image.path.parent.name}/{image.path.name}')
|
|
|
+ if image.category:
|
|
|
+ self.true_table.extend([md_image, image.ocr_result])
|
|
|
+ else:
|
|
|
+ self.false_table.extend([md_image, image.ocr_result])
|
|
|
+
|
|
|
+ true_rows = len(self.true_table) // columns
|
|
|
+ false_rows = len(self.false_table) // columns
|
|
|
+ self.write_header('True')
|
|
|
+ self.f.new_table(columns=columns, rows=true_rows, text=self.true_table, text_align=text_align)
|
|
|
+ self.write_header('False')
|
|
|
+ self.f.new_table(columns=columns, rows=false_rows, text=self.false_table, text_align='left')
|
|
|
+
|
|
|
+# if __name__ == '__main__':
|
|
|
+# markdown = MD('英语等级证书')
|
|
|
+#
|
|
|
+# dataset = Dataset(Path(''), 'cet', 'local', False)
|
|
|
+# print(len(dataset))
|
|
|
+# for d in dataset():
|
|
|
+# print(d)
|
|
|
+#
|
|
|
+# dataset.evaluate()
|
|
|
+# print(dataset.accuracy)
|
|
|
+#
|
|
|
+# markdown.write_total_accuracy(dataset)
|
|
|
+# markdown.write_table_accuracy(dataset)
|
|
|
+# markdown.write_table_result(dataset)
|
|
|
+#
|
|
|
+# markdown.f.create_md_file()
|