|
@@ -1,273 +0,0 @@
|
|
|
-from pathlib import Path
|
|
|
-from typing import List, Optional
|
|
|
-import cv2
|
|
|
-import requests
|
|
|
-from mdutils.mdutils import MdUtils
|
|
|
-from dataclasses import dataclass
|
|
|
-import json
|
|
|
-import time
|
|
|
-import base64
|
|
|
-from itertools import chain
|
|
|
-from tqdm import tqdm
|
|
|
-from ocr_config import OCR_CONFIGS, Filed
|
|
|
-
|
|
|
-
|
|
|
-class Image:
|
|
|
- def __init__(self, path: Path, rotate, is_rotate):
|
|
|
- self._path = path
|
|
|
- self.rotate = rotate
|
|
|
- self._ocr_result = None
|
|
|
- self.category = True
|
|
|
- self.is_rotate = is_rotate
|
|
|
- try:
|
|
|
- self.gt_result = self.get_json()
|
|
|
- except Exception as e:
|
|
|
- print(self.json_path)
|
|
|
- raise e
|
|
|
-
|
|
|
- def __repr__(self):
|
|
|
- return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.category}'
|
|
|
-
|
|
|
- # 将方法转换为相同名称的只读属性
|
|
|
- @property
|
|
|
- def path(self):
|
|
|
- return self._path
|
|
|
-
|
|
|
- @path.setter
|
|
|
- def path(self, path):
|
|
|
- self._path = path
|
|
|
-
|
|
|
- @property
|
|
|
- def fn(self):
|
|
|
- return self._path.stem
|
|
|
-
|
|
|
- @property
|
|
|
- def ocr_result(self):
|
|
|
- return self._ocr_result
|
|
|
-
|
|
|
- @ocr_result.setter
|
|
|
- def ocr_result(self, value):
|
|
|
- self._ocr_result = value
|
|
|
-
|
|
|
- def get_gt_result(self, key):# sourcery skip: merge-duplicate-blocks, remove-redundant-if
|
|
|
- if key == 'orientation':
|
|
|
- if self.is_rotate:
|
|
|
- return self.rotate + 1 if self.rotate is not None else 0
|
|
|
- else:
|
|
|
- return self.gt_result[key]
|
|
|
- elif key in self.gt_result:
|
|
|
- return self.gt_result[key]
|
|
|
- else:
|
|
|
- return None
|
|
|
-
|
|
|
- @property
|
|
|
- def json_path(self):
|
|
|
- return self.path.parent / f'{self.path.stem}.json'
|
|
|
-
|
|
|
- def save_image(self, img, rotate):
|
|
|
- dst = self.path.parent.parent / (".ro_dst")
|
|
|
- if not dst.exists(): dst.mkdir()
|
|
|
- self.path = dst / f'{self.path.stem}-{rotate + 1}.jpg'
|
|
|
- # print('save image', self.path)
|
|
|
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
|
|
- cv2.imwrite(str(self.path), img)
|
|
|
- return self.path
|
|
|
-
|
|
|
- def get_base64(self, rotate=None):
|
|
|
- # print(self.path)
|
|
|
- img = cv2.imread(str(self.path))
|
|
|
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
- path = self.path
|
|
|
- if rotate is not None:
|
|
|
- img = cv2.rotate(img, rotate)
|
|
|
- path = self.save_image(img, rotate)
|
|
|
- # imencode 将图片编码到缓存,并保存到本地
|
|
|
- with open(path, 'rb') as f:
|
|
|
- return base64.encodebytes(f.read()).decode('utf-8')
|
|
|
-
|
|
|
- def get_json(self):
|
|
|
- with open(self.json_path, 'r') as f:
|
|
|
- return json.load(f)
|
|
|
-
|
|
|
-
|
|
|
-def send_request(image: Image, ocr_name, ocr_address, image_type=None):
|
|
|
- base64_str = image.get_base64(image.rotate)
|
|
|
- config = OCR_CONFIGS[ocr_name][ocr_address]
|
|
|
- headers = {
|
|
|
- 'Content-Type': 'application/json',
|
|
|
- 'Authorization': f'Bearer {config.token}'
|
|
|
- }
|
|
|
- data = {
|
|
|
- 'image': base64_str,
|
|
|
- }
|
|
|
- if image_type is not None:
|
|
|
- data['image_type'] = image_type
|
|
|
- response = requests.post(config.url, headers=headers, json=data)
|
|
|
- return response.json()
|
|
|
-
|
|
|
-
|
|
|
-def parser_path(path: Path, rotate: bool):
|
|
|
- name = time.strftime("%m-%d_", time.localtime()) + path.name
|
|
|
- if rotate:
|
|
|
- name = f'{name}_R.md'
|
|
|
- return path.parent / name
|
|
|
-
|
|
|
-
|
|
|
-class Dataset(object):
|
|
|
- def __init__(self, images_path, image_type, ocr_name, ocr_address, field, rotate=False):
|
|
|
- self.image_type = image_type
|
|
|
- self.ocr_name = ocr_name
|
|
|
- self.ocr_address = ocr_address
|
|
|
- self.images_path = images_path
|
|
|
- self.image_list = []
|
|
|
- # chain 迭代器,首先返回第一个可迭代对象中所有元素,接着返回下一个可迭代对象中所有元素,直到耗尽所有可迭代对象中的元素
|
|
|
- # eg:chain('ABC', 'DEF') --> A B C D E F
|
|
|
-
|
|
|
- for p in chain(*[Path(self.images_path).rglob('*.jpg')]):
|
|
|
- if rotate:
|
|
|
- self.image_list.extend(Image(p, r, rotate) for r in [None, 0, 1, 2])
|
|
|
- else:
|
|
|
- self.image_list.append(Image(p, None, rotate))
|
|
|
-
|
|
|
- if ocr_name == 'regbook':
|
|
|
- self.field = Filed.get(field + str(image_type))
|
|
|
- else:
|
|
|
- self.field = Filed.get(field)
|
|
|
-
|
|
|
- self.correct = {k: 0 for k in self.field}
|
|
|
- self.error = {k: 0 for k in self.field}
|
|
|
-
|
|
|
- def __len__(self):
|
|
|
- return len(self.image_list)
|
|
|
-
|
|
|
- def _evaluate_one(self, image: Image):
|
|
|
- def _get_predict(r, key):
|
|
|
- # isinstance() 函数来判断一个对象是否是一个已知的类型
|
|
|
- if isinstance(r[key], dict):
|
|
|
- return r[key]['text']
|
|
|
- else:
|
|
|
- return r[key]
|
|
|
-
|
|
|
- if image.rotate is not None: image.gt_result['orientation'] = image.rotate + 1
|
|
|
- r = send_request(image, self.ocr_name, self.ocr_address, self.image_type)
|
|
|
- err_str = ''
|
|
|
- if r['status'] == '000':
|
|
|
- res = r['result']
|
|
|
- for key in self.field:
|
|
|
- # print('attr: ', key)
|
|
|
- if key in res:
|
|
|
- gt = image.get_gt_result(key)
|
|
|
- predict = _get_predict(res, key)
|
|
|
- # print(f'gt: {gt}, predict: {predict}')
|
|
|
- if predict == gt:
|
|
|
- self.correct[key] += 1
|
|
|
- else:
|
|
|
- image.category = False
|
|
|
- self.error[key] += 1
|
|
|
- err_str += f'-------{key}-------<br>正确:{gt}<br>返回:{predict}<br>'
|
|
|
- if image.category:
|
|
|
- image.ocr_result = image.gt_result
|
|
|
- else:
|
|
|
- image.ocr_result = err_str
|
|
|
- else:
|
|
|
- image.ocr_result = r['msg']
|
|
|
- image.category = False
|
|
|
- for key in self.field:
|
|
|
- self.error[key] += 1
|
|
|
-
|
|
|
- def __call__(self): # sourcery skip: yield-from
|
|
|
- # yield 返回一个生成器
|
|
|
- for image in self.image_list:
|
|
|
- yield image
|
|
|
-
|
|
|
- # 比较
|
|
|
- def evaluate(self):
|
|
|
- for image in tqdm(self.image_list):
|
|
|
- self._evaluate_one(image)
|
|
|
-
|
|
|
- # 计算总体准确度
|
|
|
- @property
|
|
|
- def accuracy(self):
|
|
|
- return sum(list(self.correct.values())) / sum(list(self.correct.values()) + list(self.error.values()))
|
|
|
-
|
|
|
- # 计算元素准确度
|
|
|
- @property
|
|
|
- def attrs_accuracy(self):
|
|
|
- return {k: self.correct[k] / (self.correct[k] + self.error[k]) for k in self.field}
|
|
|
-
|
|
|
-
|
|
|
-class MD(object):
|
|
|
- def __init__(self, file_path: Path):
|
|
|
- self.name = file_path.name
|
|
|
- self.f = MdUtils(file_name=str(file_path))
|
|
|
- self.field_table: List = ['字段', '正确率']
|
|
|
- self.true_table: List = ['图片', '识别结果']
|
|
|
- self.false_table: List = ['图片', '识别结果']
|
|
|
- self.write_header(f'{self.name}测试报告')
|
|
|
-
|
|
|
- def write_header(self, title, level=1):
|
|
|
- self.f.new_header(level=level, title=title)
|
|
|
-
|
|
|
- def write_total_accuracy(self, ds: Dataset):
|
|
|
- def get_format_total_accuracy(ds: Dataset):
|
|
|
- acc = ds.accuracy * 100
|
|
|
- return "{:.2f}%".format(acc)
|
|
|
-
|
|
|
- # 1. 拿到format之后的百分数
|
|
|
- res = get_format_total_accuracy(ds)
|
|
|
-
|
|
|
- # 2. 写入
|
|
|
- self.f.new_paragraph(res)
|
|
|
-
|
|
|
- def write_table_accuracy(self, ds: Dataset, columns=2, text_align='center'):
|
|
|
- def format_table_accuracy(ds: Dataset):
|
|
|
- table = ds.attrs_accuracy
|
|
|
- for k, v in table.items():
|
|
|
- acc = v * 100
|
|
|
- table[k] = "{:.2f}%".format(acc)
|
|
|
- return table
|
|
|
-
|
|
|
- def dict_2_list(dic: dict):
|
|
|
- l = []
|
|
|
- for k, v in dic.items():
|
|
|
- l.extend((k, v))
|
|
|
- return l
|
|
|
-
|
|
|
- table_dict = format_table_accuracy(ds)
|
|
|
- table_list = dict_2_list(table_dict)
|
|
|
- self.field_table.extend(table_list)
|
|
|
-
|
|
|
- rows = len(self.field_table) // columns
|
|
|
- self.f.new_table(columns=columns, rows=rows, text=self.field_table, text_align=text_align)
|
|
|
-
|
|
|
- def write_table_result(self, ds: Dataset, columns=2, text_align='center'):
|
|
|
- for image in ds.image_list:
|
|
|
- md_image = self.f.new_inline_image(text='', path=f'{image.path.parent.name}/{image.path.name}')
|
|
|
- if image.category:
|
|
|
- self.true_table.extend([md_image, image.ocr_result])
|
|
|
- else:
|
|
|
- self.false_table.extend([md_image, image.ocr_result])
|
|
|
-
|
|
|
- true_rows = len(self.true_table) // columns
|
|
|
- false_rows = len(self.false_table) // columns
|
|
|
- self.write_header('True')
|
|
|
- self.f.new_table(columns=columns, rows=true_rows, text=self.true_table, text_align=text_align)
|
|
|
- self.write_header('False')
|
|
|
- self.f.new_table(columns=columns, rows=false_rows, text=self.false_table, text_align='left')
|
|
|
-
|
|
|
-# if __name__ == '__main__':
|
|
|
-# markdown = MD('英语等级证书')
|
|
|
-#
|
|
|
-# dataset = Dataset(Path(''), 'cet', 'local', False)
|
|
|
-# print(len(dataset))
|
|
|
-# for d in dataset():
|
|
|
-# print(d)
|
|
|
-#
|
|
|
-# dataset.evaluate()
|
|
|
-# print(dataset.accuracy)
|
|
|
-#
|
|
|
-# markdown.write_total_accuracy(dataset)
|
|
|
-# markdown.write_table_accuracy(dataset)
|
|
|
-# markdown.write_table_result(dataset)
|
|
|
-#
|
|
|
-# markdown.f.create_md_file()
|