123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270 |
- from pathlib import Path
- from typing import List, Optional
- import cv2
- import requests
- from mdutils.mdutils import MdUtils
- from dataclasses import dataclass
- import json
- import time
- import base64
- from itertools import chain
- from tqdm import tqdm
- from ocr_config import OCR_CONFIGS, Filed
class Image:
    """A test image together with its ground-truth annotation.

    The annotation is read from ``<stem>.json`` beside the image when the
    object is created.

    Args:
        path: Path of the image file.
        rotate: ``None`` to use the image as-is, otherwise a rotate code that
            ``get_base64`` passes to ``cv2.rotate``.
        is_rotate: True when the image belongs to a rotation run; the expected
            ``orientation`` is then derived from ``rotate`` rather than read
            from the annotation.
    """

    def __init__(self, path: Path, rotate, is_rotate):
        self._path = path
        # Source image on disk. ``_path`` is redirected to the rotated copy by
        # ``save_image``; reading and naming always start from this path so a
        # repeated rotation never rotates an already-rotated copy.
        self._src_path = path
        self.rotate = rotate
        self._ocr_result = None
        # Stays True unless evaluation finds a wrong field or a failed request.
        self.category = True
        self.is_rotate = is_rotate
        try:
            self.gt_result = self.get_json()
        except Exception:
            # Report which annotation file could not be loaded, then propagate.
            print(self.json_path)
            raise

    def __repr__(self):
        return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.category}'

    # Read/write properties. ``path`` points at the rotated copy once
    # ``save_image`` has run.
    @property
    def path(self):
        return self._path

    @path.setter
    def path(self, path):
        # An explicit assignment re-targets the image entirely, source included.
        self._path = path
        self._src_path = path

    @property
    def fn(self):
        """File name stem of the current ``path``."""
        return self._path.stem

    @property
    def ocr_result(self):
        return self._ocr_result

    @ocr_result.setter
    def ocr_result(self, value):
        self._ocr_result = value

    def get_gt_result(self, key):
        """Return the expected value for ``key``, or None when it is not annotated.

        In a rotation run the expected ``orientation`` is ``rotate + 1``, or 0
        for the unrotated copy, instead of the annotated value.
        """
        if key == 'orientation' and self.is_rotate:
            return 0 if self.rotate is None else self.rotate + 1
        # Missing keys (orientation included) consistently yield None.
        return self.gt_result.get(key)

    @property
    def json_path(self):
        """Annotation file: ``<stem>.json`` next to the source image."""
        return self._src_path.parent / f'{self._src_path.stem}.json'

    def save_image(self, img, rotate):
        """Save an RGB image array as the rotated copy and point ``path`` at it.

        The file is written to ``<source dir>/../.ro_dst/<source stem>-<rotate + 1>.jpg``.
        Directory and name derive from the source image, so repeated calls
        overwrite the same file instead of nesting paths.

        Returns:
            The path of the written file.

        Raises:
            OSError: if OpenCV reports that the file could not be written.
        """
        src = self._src_path
        dst = src.parent.parent / ".ro_dst"
        dst.mkdir(exist_ok=True)
        out = dst / f'{src.stem}-{rotate + 1}.jpg'
        # Callers pass RGB; OpenCV writes BGR.
        if not cv2.imwrite(str(out), cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
            # Without this check a stale file from an earlier run could be sent.
            raise OSError(f'failed to write image: {out}')
        self._path = out
        return out

    def get_base64(self, rotate=None):
        """Return the image file as base64 text (``base64.encodebytes`` output).

        Args:
            rotate: ``None`` to encode the current file unchanged; otherwise a
                ``cv2.rotate`` code applied to the source image, whose result
                is saved via ``save_image`` and encoded instead.

        Raises:
            FileNotFoundError: if the source image cannot be read by OpenCV.
        """
        if rotate is None:
            # No pixel work needed: send the file bytes as they are.
            path = self._path
        else:
            img = cv2.imread(str(self._src_path))
            if img is None:
                raise FileNotFoundError(f'cannot read image: {self._src_path}')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            path = self.save_image(cv2.rotate(img, rotate), rotate)
        with open(path, 'rb') as f:
            return base64.encodebytes(f.read()).decode('utf-8')

    def get_json(self):
        """Load and return the annotation JSON (read as UTF-8)."""
        with open(self.json_path, 'r', encoding='utf-8') as f:
            return json.load(f)
def send_request(image: Image, ocr_name, ocr_address, image_type=None, timeout=60):
    """POST ``image`` to the configured OCR endpoint and return the decoded JSON.

    Args:
        image: Image to send; its ``rotate`` setting is applied before encoding.
        ocr_name: First-level key into ``OCR_CONFIGS``.
        ocr_address: Second-level key into ``OCR_CONFIGS``; the selected entry
            supplies the request ``url`` and the ``Authorization`` ``token``.
        image_type: Optional value sent as ``image_type`` in the payload.
        timeout: Seconds ``requests`` waits for the server (connect and read)
            before raising ``requests.Timeout``. ``None`` waits indefinitely,
            which can stall a whole evaluation run on one unresponsive call.

    Returns:
        The response body parsed as JSON.
    """
    # Resolve the endpoint first so a bad config key fails before the
    # (possibly rotating, file-writing) image encoding.
    config = OCR_CONFIGS[ocr_name][ocr_address]
    headers = {
        'Content-Type': 'application/json',
        'Authorization': config.token,
    }
    data = {'image': image.get_base64(image.rotate)}
    if image_type is not None:
        data['image_type'] = image_type
    response = requests.post(config.url, headers=headers, json=data, timeout=timeout)
    return response.json()
def parser_path(path: Path, rotate: bool):
    """Build the report path that sits next to the dataset directory ``path``.

    The name is the local date as ``MM-DD_`` followed by ``path.name``;
    rotation runs additionally get the suffix ``_R.md``.
    """
    stamp = time.strftime("%m-%d_", time.localtime())
    suffix = '_R.md' if rotate else ''
    return path.parent / f'{stamp}{path.name}{suffix}'
class Dataset(object):
    """A directory of annotated images evaluated against one OCR endpoint.

    Every ``*.jpg`` found recursively under ``images_path`` becomes an
    ``Image``. With ``rotate=True`` each file is evaluated four times: as-is
    and with the ``cv2.rotate`` codes 0, 1 and 2.

    ``evaluate`` fills the per-field counters ``correct`` and ``error``, from
    which ``accuracy`` and ``attrs_accuracy`` are derived.
    """

    def __init__(self, images_path, image_type, ocr_name, ocr_address, field, rotate=False):
        self.image_type = image_type
        self.ocr_name = ocr_name
        self.ocr_address = ocr_address
        self.images_path = images_path
        self.image_list = []
        # Sorted so the report lists images in the same order on every run
        # (raw rglob order depends on the file system).
        for p in sorted(Path(self.images_path).rglob('*.jpg')):
            if rotate:
                self.image_list.extend(Image(p, r, rotate) for r in (None, 0, 1, 2))
            else:
                self.image_list.append(Image(p, None, rotate))
        # Field names to score, looked up in the project's ``Filed`` table.
        self.field = Filed.get(field)
        self.correct = {k: 0 for k in self.field}
        self.error = {k: 0 for k in self.field}

    def __len__(self):
        return len(self.image_list)

    def _evaluate_one(self, image: 'Image'):
        """Send ``image`` to the OCR service and score every configured field.

        Updates ``correct`` / ``error`` and sets on ``image``:
        ``category`` (False on any mismatch or failed request) and
        ``ocr_result`` (the ground truth on full success, otherwise an HTML
        ``<br>``-separated mismatch description or the service ``msg``).
        """
        def _get_predict(res, key):
            # A returned field is either a plain value or a dict with ``text``.
            value = res.get(key)
            if isinstance(value, dict):
                return value['text']
            return value

        if image.rotate is not None:
            image.gt_result['orientation'] = image.rotate + 1
        r = send_request(image, self.ocr_name, self.ocr_address, self.image_type)

        if r.get('status') != '000':
            # Request-level failure: every field of this image counts as wrong.
            image.ocr_result = r.get('msg')
            image.category = False
            for key in self.field:
                self.error[key] += 1
            return

        res = r.get('result') or {}
        errors = []
        for key in self.field:
            gt = image.get_gt_result(key)
            if key not in res and gt is None:
                # Neither annotated nor returned: nothing to score.
                continue
            # An annotated field the service omitted is scored as a miss
            # (predict is None) instead of being silently skipped.
            predict = _get_predict(res, key)
            if predict == gt:
                self.correct[key] += 1
            else:
                image.category = False
                self.error[key] += 1
                errors.append(f'-------{key}-------<br>正确:{gt}<br>返回:{predict}<br>')
        image.ocr_result = image.gt_result if image.category else ''.join(errors)

    def __call__(self):
        """Yield the images in evaluation order."""
        yield from self.image_list

    def evaluate(self):
        """Score every image, showing a progress bar."""
        for image in tqdm(self.image_list):
            self._evaluate_one(image)

    @property
    def accuracy(self):
        """Share of correct field comparisons overall; 0.0 if nothing was scored."""
        hits = sum(self.correct.values())
        total = hits + sum(self.error.values())
        return hits / total if total else 0.0

    @property
    def attrs_accuracy(self):
        """Per-field share of correct comparisons; 0.0 for a never-scored field."""
        result = {}
        for k in self.field:
            total = self.correct[k] + self.error[k]
            result[k] = self.correct[k] / total if total else 0.0
        return result
class MD(object):
    """Markdown evaluation report for a ``Dataset``, written through ``MdUtils``.

    A level-1 title ``<file name>测试报告`` is added on creation; each
    ``write_*`` method appends one section. Tables are held as flat,
    row-major cell lists whose first two cells form the header row.
    """

    def __init__(self, file_path: Path):
        self.name = file_path.name
        self.f = MdUtils(file_name=str(file_path))
        # Header: field / accuracy.
        self.field_table: List = ['字段', '正确率']
        # Header: image / recognition result.
        self.true_table: List = ['图片', '识别结果']
        self.false_table: List = ['图片', '识别结果']
        title = f'{self.name}测试报告'
        self.write_header(title)

    def write_header(self, title, level=1):
        """Append a header of the given level."""
        self.f.new_header(level=level, title=title)

    def write_total_accuracy(self, ds: Dataset):
        """Append the overall accuracy of ``ds`` as a two-decimal percentage paragraph."""
        percent = ds.accuracy * 100
        self.f.new_paragraph(f'{percent:.2f}%')

    def write_table_accuracy(self, ds: Dataset, columns=2, text_align='center'):
        """Append a table with one (field, percentage) row per scored field."""
        for field_name, ratio in ds.attrs_accuracy.items():
            self.field_table += [field_name, f'{ratio * 100:.2f}%']
        self.f.new_table(
            columns=columns,
            rows=len(self.field_table) // columns,
            text=self.field_table,
            text_align=text_align,
        )

    def write_table_result(self, ds: Dataset, columns=2, text_align='center'):
        """Append the "True" and "False" tables: one (image, result) row per image.

        Images are linked relative to the report as ``<dir name>/<file name>``.
        The "False" table is always left-aligned.
        """
        for image in ds.image_list:
            link = self.f.new_inline_image(text='', path=f'{image.path.parent.name}/{image.path.name}')
            bucket = self.true_table if image.category else self.false_table
            bucket.extend([link, image.ocr_result])
        sections = (
            ('True', self.true_table, text_align),
            ('False', self.false_table, 'left'),
        )
        row_counts = [len(cells) // columns for _, cells, _ in sections]
        for (heading, cells, align), row_count in zip(sections, row_counts):
            self.write_header(heading)
            self.f.new_table(columns=columns, rows=row_count, text=cells, text_align=align)
- # if __name__ == '__main__':
- # markdown = MD('英语等级证书')
- #
- # dataset = Dataset(Path(''), 'cet', 'local', False)
- # print(len(dataset))
- # for d in dataset():
- # print(d)
- #
- # dataset.evaluate()
- # print(dataset.accuracy)
- #
- # markdown.write_total_accuracy(dataset)
- # markdown.write_table_accuracy(dataset)
- # markdown.write_table_result(dataset)
- #
- # markdown.f.create_md_file()
|