import base64
import json
import time
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from typing import List, Optional

import cv2
import numpy as np
import requests
from mdutils.mdutils import MdUtils
from tqdm import tqdm

from ocr_config import OCR_CONFIGS


class Image:
    """A test image plus the ground-truth annotation stored next to it.

    The annotation is read from a ``.json`` file that shares the image's
    stem and directory. ``rotate`` is either ``None`` (use the image as-is)
    or a value handed unchanged to ``cv2.rotate`` as its rotate code.
    """

    def __init__(self, path: Path, rotate):
        self._path = path
        self.rotate = rotate
        self._ocr_result = None
        # True until a field comparison fails or the service reports an error.
        self.category = True
        try:
            self.gt_result = self.get_json()
        except Exception:
            # Show which annotation file is broken, then keep the original
            # traceback intact.
            print(self.json_path)
            raise

    def __repr__(self):
        return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.category}'

    @property
    def path(self):
        """Current location of the image file (changes after ``save_image``)."""
        return self._path

    @path.setter
    def path(self, path):
        self._path = path

    @property
    def fn(self):
        """File name of the image without its extension."""
        return self._path.stem

    @property
    def ocr_result(self):
        """Value recorded for the report by the evaluation step."""
        return self._ocr_result

    @ocr_result.setter
    def ocr_result(self, value):
        self._ocr_result = value

    def get_gt_result(self, key):
        """Return the expected value for ``key``, or ``None`` if unknown.

        ``'orientation'`` is derived from ``self.rotate`` (``rotate + 1``, or
        ``0`` when no rotation is applied) instead of being read from the
        annotation file.
        """
        if key == 'orientation':
            return self.rotate + 1 if self.rotate is not None else 0
        return self.gt_result.get(key)

    @property
    def json_path(self):
        """Path of the annotation file belonging to the current ``path``."""
        return self.path.parent / f'{self.path.stem}.json'

    def save_image(self, img, rotate):
        """Write the rotated image to ``<grandparent>/.ro_dst`` and repoint ``path``.

        Args:
            img: RGB image array; converted back to BGR before writing.
            rotate: Rotation code, used only to build the file name suffix.
        """
        dst = self.path.parent.parent / ".ro_dst"
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        dst.mkdir(exist_ok=True)
        self.path = dst / f'{self.path.stem}-{rotate + 1}.jpg'
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(self.path), img)

    def get_base64(self, rotate=None):
        """Return the (optionally rotated) image as a base64-encoded JPEG string.

        When ``rotate`` is not ``None`` the rotated image is also saved via
        ``save_image``, which changes ``self.path``; a second rotated call
        would therefore start from the already-rotated copy.

        Raises:
            OSError: If the image file cannot be read or decoded.
        """
        img = cv2.imread(str(self.path))
        if img is None:
            # cv2.imread signals failure by returning None; fail here with the
            # offending path instead of an opaque assertion inside cvtColor.
            raise OSError(f'failed to read image: {self.path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if rotate is not None:
            img = cv2.rotate(img, rotate)
            self.save_image(img, rotate)
        # NOTE(review): the array is RGB-ordered at this point while imencode
        # presumably expects BGR, so the payload may have swapped channels.
        # Kept as-is to preserve what is sent to the service -- confirm.
        _, img = cv2.imencode('.jpg', img)
        return base64.b64encode(img).decode('utf-8')

    def get_json(self):
        """Load and return the ground-truth annotation from ``json_path``."""
        # Explicit UTF-8 so non-ASCII annotations do not depend on the locale.
        with open(self.json_path, 'r', encoding='utf-8') as f:
            return json.load(f)


def send_request(image: Image, ocr_name, ocr_address, image_type=None,
                 timeout: Optional[float] = None):
    """POST one image to the configured OCR endpoint and return the parsed JSON.

    Args:
        image: Image to send; it is encoded with ``image.get_base64(image.rotate)``.
        ocr_name: First-level key into ``OCR_CONFIGS``.
        ocr_address: Second-level key into ``OCR_CONFIGS``.
        image_type: Optional value sent as ``image_type`` in the request body.
        timeout: Seconds passed to ``requests.post``; ``None`` (the default)
            keeps the previous behaviour of waiting indefinitely.

    Returns:
        The response body decoded from JSON.
    """
    base64_str = image.get_base64(image.rotate)
    config = OCR_CONFIGS[ocr_name][ocr_address]
    headers = {
        'Content-Type': 'application/json',
        'Authorization': config.token,
    }
    data = {
        'image': base64_str,
    }
    if image_type is not None:
        data['image_type'] = image_type
    response = requests.post(config.url, headers=headers, json=data, timeout=timeout)
    return response.json()


def parser_path(path: Path, rotate: bool):
    """Build the report path: ``<MM-DD>_<name>`` next to ``path``.

    When ``rotate`` is true, ``_R.md`` is appended to the name.
    """
    name = time.strftime("%m-%d_", time.localtime()) + path.name
    if rotate:
        name = f'{name}_R.md'
    return path.parent / name


class Dataset(object):
    """Collection of ``Image`` objects plus per-field correct/error counters."""

    def __init__(self, images_path, image_type, ocr_name, ocr_address, rotate=False):
        """Collect every ``*.jpg`` under ``images_path`` (recursively).

        With ``rotate`` enabled each file yields four entries: the original
        (``None``) and rotate codes ``0``, ``1`` and ``2``.
        """
        self.image_type = image_type
        self.ocr_name = ocr_name
        self.ocr_address = ocr_address
        self.images_path = images_path
        self.image_list = []
        for p in Path(self.images_path).rglob('*.jpg'):
            if rotate:
                self.image_list.extend(Image(p, r) for r in [None, 0, 1, 2])
            else:
                self.image_list.append(Image(p, None))
        self.field = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']
        # Alternative field sets kept for reference:
        # if self.image_type:
        #     self.field = ['orientation', 'type', 'address', 'address_province', 'address_city', 'address_region',
        #                   'address_detail']
        # else:
        #     self.field = ['orientation', 'name', 'id', 'gender', 'birthplace', 'birthplace_province', 'birthplace_city',
        #                   'birthplace_region', 'native_place', 'native_place_province', 'native_place_city',
        #                   'native_place_region', 'blood_type', 'religion']
        self.correct = {k: 0 for k in self.field}
        self.error = {k: 0 for k in self.field}

    def __len__(self):
        return len(self.image_list)

    def _evaluate_one(self, image: Image):
        """Send one image to the OCR service and update the counters.

        A response whose ``status`` is ``'000'`` is compared field by field
        against the ground truth; fields absent from ``result`` are counted
        as neither correct nor wrong. Any other status counts as an error
        for every field and stores the response ``msg`` on the image.
        """
        def _get_predict(r, key):
            # A field value may be a dict carrying the text under 'text'.
            value = r[key]
            return value['text'] if isinstance(value, dict) else value

        if image.rotate is not None:
            image.gt_result['orientation'] = image.rotate + 1
        r = send_request(image, self.ocr_name, self.ocr_address, self.image_type)
        err_str = ''
        if r['status'] == '000':
            res = r['result']
            for key in self.field:
                if key not in res:
                    continue
                gt = image.get_gt_result(key)
                predict = _get_predict(res, key)
                if predict == gt:
                    self.correct[key] += 1
                else:
                    image.category = False
                    self.error[key] += 1
                    err_str += f'正确:{gt}\n返回:{predict}\n'
            image.ocr_result = image.gt_result if image.category else err_str
        else:
            image.ocr_result = r['msg']
            image.category = False
            for key in self.field:
                self.error[key] += 1

    def __call__(self):
        """Return a generator over the images in the dataset."""
        yield from self.image_list

    def evaluate(self):
        """Evaluate every image, showing a progress bar."""
        for image in tqdm(self.image_list):
            self._evaluate_one(image)

    @staticmethod
    def _ratio(correct, error):
        """Return ``correct / (correct + error)``, or 0.0 when nothing was counted."""
        total = correct + error
        return correct / total if total else 0.0

    @property
    def accuracy(self):
        """Overall accuracy across all fields (0.0 if nothing was counted)."""
        return self._ratio(sum(self.correct.values()), sum(self.error.values()))

    @property
    def attrs_accuracy(self):
        """Per-field accuracy as a new dict (0.0 for fields never counted)."""
        return {k: self._ratio(self.correct[k], self.error[k]) for k in self.field}


class MD(object):
    """Writes the evaluation report as a Markdown document via ``MdUtils``."""

    def __init__(self, file_path: Path):
        self.name = file_path.name
        self.f = MdUtils(file_name=str(file_path))
        # Flat cell lists, header row first, in the form new_table expects.
        self.field_table: List = ['字段', '正确率']
        self.true_table: List = ['图片', '识别结果']
        self.false_table: List = ['图片', '识别结果']
        self.write_header(f'{self.name}测试报告')

    def write_header(self, title, level=1):
        """Add a header with the given title and level."""
        self.f.new_header(level=level, title=title)

    def write_total_accuracy(self, ds: Dataset):
        """Write the overall accuracy as a percentage with two decimals."""
        self.f.new_paragraph(f'{ds.accuracy * 100:.2f}%')

    def write_table_accuracy(self, ds: Dataset, columns=2, text_align='center'):
        """Append the per-field accuracies to ``field_table`` and write the table."""
        for field, acc in ds.attrs_accuracy.items():
            self.field_table.extend((field, f'{acc * 100:.2f}%'))
        rows = len(self.field_table) // columns
        self.f.new_table(columns=columns, rows=rows, text=self.field_table, text_align=text_align)

    def write_table_result(self, ds: Dataset, columns=2, text_align='center'):
        """Write the 'True' and 'False' tables pairing each image with its result."""
        for image in ds.image_list:
            md_image = self.f.new_inline_image(text='', path=str(image.path))
            # ocr_result is the ground-truth dict for correct images, so
            # stringify it before handing it over as a table cell.
            # NOTE(review): error strings contain newlines, which may break the
            # rendered table rows -- confirm against the generated report.
            cell = str(image.ocr_result)
            target = self.true_table if image.category else self.false_table
            target.extend([md_image, cell])
        true_rows = len(self.true_table) // columns
        false_rows = len(self.false_table) // columns
        self.write_header('True')
        self.f.new_table(columns=columns, rows=true_rows, text=self.true_table, text_align=text_align)
        self.write_header('False')
        self.f.new_table(columns=columns, rows=false_rows, text=self.false_table, text_align=text_align)

# if __name__ == '__main__':
#     markdown = MD('英语等级证书')
#
#     dataset = Dataset(Path(''), 'cet', 'local', False)
#     print(len(dataset))
#     for d in dataset():
#         print(d)
#
#     dataset.evaluate()
#     print(dataset.accuracy)
#
#     markdown.write_total_accuracy(dataset)
#     markdown.write_table_accuracy(dataset)
#     markdown.write_table_result(dataset)
#
#     markdown.f.create_md_file()