from pathlib import Path
from typing import List, Optional
import cv2
import requests
from mdutils.mdutils import MdUtils
from dataclasses import dataclass
import json
import time
import base64
from itertools import chain
from tqdm import tqdm
import numpy as np
from ocr_config import OCR_CONFIGS
class Image:
    """One test image together with its ground-truth annotation and OCR outcome.

    The ground truth is loaded at construction time from a JSON file that sits
    next to the image and shares its stem (``foo.jpg`` -> ``foo.json``).

    Attributes:
        rotate: ``None`` for the image as stored, otherwise the flag passed to
            ``cv2.rotate`` before the image is encoded.
        gt_result: Parsed content of the sibling JSON file.
        category: ``True`` until a caller marks the image as failed.
    """

    def __init__(self, path: Path, rotate):
        """Remember the image location and load its ground truth.

        Args:
            path: Location of the image file.
            rotate: ``None`` or a ``cv2.rotate`` flag.

        Raises:
            Exception: Whatever reading or parsing the JSON file raises; the
                JSON path is printed first to identify the offending file.
        """
        self._path = path
        self.rotate = rotate
        self._ocr_result = None
        self.category = True
        try:
            self.gt_result = self.get_json()
        except Exception:
            # Show which annotation file is broken before propagating.
            print(self.json_path)
            raise

    def __repr__(self):
        return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.category}'

    @property
    def path(self):
        """Current image location (re-pointed by ``save_image``)."""
        return self._path

    @path.setter
    def path(self, path):
        self._path = path

    @property
    def fn(self):
        """File name of the current image path without its extension."""
        return self._path.stem

    @property
    def ocr_result(self):
        """Value stored by the evaluator; ``None`` until it is set."""
        return self._ocr_result

    @ocr_result.setter
    def ocr_result(self, value):
        self._ocr_result = value

    def get_gt_result(self, key):
        """Return the expected value for ``key``.

        ``'orientation'`` is derived from ``rotate`` (``rotate + 1``, or ``0``
        when the image is not rotated). Any other key is looked up in the
        ground-truth JSON, yielding ``None`` when it is absent.
        """
        if key == 'orientation':
            return 0 if self.rotate is None else self.rotate + 1
        return self.gt_result.get(key)

    @property
    def json_path(self):
        """Path of the JSON file next to the current image path."""
        return self.path.parent / f'{self.path.stem}.json'

    def save_image(self, img, rotate):
        """Write a rotated copy into the ``.ro_dst`` directory and re-point ``path`` at it.

        Args:
            img: Pixel data in RGB channel order.
            rotate: The rotation flag; ``rotate + 1`` becomes the file suffix.

        NOTE: ``self.path`` is replaced by the copy's location, so ``fn`` and
        ``json_path`` refer to the copy afterwards.
        """
        dst = self.path.parent.parent / ".ro_dst"
        # exist_ok avoids the check-then-create race of a separate exists() test.
        dst.mkdir(exist_ok=True)
        self.path = dst / f'{self.path.stem}-{rotate + 1}.jpg'
        # cv2.imwrite expects BGR channel order.
        cv2.imwrite(str(self.path), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

    def get_base64(self, rotate=None):
        """Return the image as a base64-encoded JPEG string.

        When ``rotate`` is given, the image is rotated with ``cv2.rotate``, the
        rotated copy is saved through ``save_image`` (which changes ``path``),
        and the rotated pixels are what gets encoded.

        Raises:
            ValueError: If the file cannot be read as an image or the JPEG
                encoding fails.
        """
        bgr = cv2.imread(str(self.path))
        if bgr is None:
            # imread signals failure by returning None rather than raising.
            raise ValueError(f'cannot read image (missing or not decodable): {self.path}')
        if rotate is not None:
            bgr = cv2.rotate(bgr, rotate)
            # save_image takes RGB input, so convert only for that call.
            self.save_image(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), rotate)
        # imencode, like imwrite, expects BGR; encoding an RGB array would
        # swap the red and blue channels in the JPEG that is sent out.
        ok, encoded = cv2.imencode('.jpg', bgr)
        if not ok:
            raise ValueError(f'cannot encode image as JPEG: {self.path}')
        return base64.b64encode(encoded).decode('utf-8')

    def get_json(self):
        """Parse and return the JSON file located at ``json_path``."""
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(self.json_path, 'r', encoding='utf-8') as f:
            return json.load(f)
def send_request(image: "Image", ocr_name, ocr_address, image_type=None, timeout=None):
    """POST one image to the configured OCR endpoint and return the decoded JSON reply.

    Args:
        image: Image to submit; it is encoded via ``image.get_base64(image.rotate)``,
            so a rotated copy is saved as a side effect when ``image.rotate`` is set.
        ocr_name: First-level key into ``OCR_CONFIGS``.
        ocr_address: Second-level key into ``OCR_CONFIGS``; the selected entry
            supplies ``url`` and ``token``.
        image_type: Optional value sent as ``image_type``; omitted from the
            payload when ``None``.
        timeout: Optional timeout forwarded to ``requests.post``. The default
            ``None`` waits indefinitely, which is the previous behaviour; pass
            a number of seconds to stop a stalled service from hanging the run.

    Returns:
        The response body parsed as JSON.
    """
    base64_str = image.get_base64(image.rotate)
    config = OCR_CONFIGS[ocr_name][ocr_address]
    headers = {
        'Content-Type': 'application/json',
        'Authorization': config.token,
    }
    data = {'image': base64_str}
    if image_type is not None:
        data['image_type'] = image_type
    response = requests.post(config.url, headers=headers, json=data, timeout=timeout)
    return response.json()
def parser_path(path: Path, rotate: bool):
    """Return ``path`` with its final component prefixed by the local date.

    The prefix has the form ``MM-DD_``. When ``rotate`` is truthy the name
    additionally gets the ``_R.md`` suffix; otherwise no suffix is appended.

    Args:
        path: Path whose last component is used as the base name.
        rotate: Whether to mark the name as belonging to a rotated run.

    Returns:
        A path in the same parent directory carrying the new name.
    """
    stamped = f'{time.strftime("%m-%d_", time.localtime())}{path.name}'
    if not rotate:
        return path.parent / stamped
    return path.parent / f'{stamped}_R.md'
class Dataset(object):
    """Collection of test images plus per-field correct/error counters.

    Every ``*.jpg`` found recursively below ``images_path`` becomes one
    ``Image`` (or four, when ``rotate`` is truthy: unrotated plus the rotation
    flags 0, 1 and 2).

    Attributes:
        image_list: The ``Image`` objects to evaluate.
        field: Names of the fields that are compared.
        correct: Per-field count of matching predictions.
        error: Per-field count of mismatches and failed requests.
    """

    def __init__(self, images_path, image_type, ocr_name, ocr_address, rotate=False):
        """Discover the images and reset the counters.

        Args:
            images_path: Directory searched recursively for ``*.jpg`` files.
            image_type: Forwarded to ``send_request`` for every image.
            ocr_name: Forwarded to ``send_request``.
            ocr_address: Forwarded to ``send_request``.
            rotate: When truthy, also evaluate three rotated variants per file.
        """
        self.image_type = image_type
        self.ocr_name = ocr_name
        self.ocr_address = ocr_address
        self.images_path = images_path
        self.image_list = []
        rotations = [None, 0, 1, 2] if rotate else [None]
        for p in Path(self.images_path).rglob('*.jpg'):
            self.image_list.extend(Image(p, r) for r in rotations)
        self.field = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']
        # Alternative field sets kept for reference:
        # if self.image_type:
        #     self.field = ['orientation', 'type', 'address', 'address_province', 'address_city', 'address_region',
        #                   'address_detail']
        # else:
        #     self.field = ['orientation', 'name', 'id', 'gender', 'birthplace', 'birthplace_province', 'birthplace_city',
        #                   'birthplace_region', 'native_place', 'native_place_province', 'native_place_city',
        #                   'native_place_region', 'blood_type', 'religion']
        self.correct = {k: 0 for k in self.field}
        self.error = {k: 0 for k in self.field}

    def __len__(self):
        return len(self.image_list)

    @staticmethod
    def _ratio(hits, total):
        """Return ``hits / total``, or ``0.0`` when nothing was counted."""
        return hits / total if total else 0.0

    def _evaluate_one(self, image: "Image"):
        """Send one image to the OCR service and update the counters.

        On status ``'000'`` each field present in the response is compared with
        the ground truth; fields missing from the response are not counted at
        all. On any other status every field is counted as an error and the
        response message becomes the image's ``ocr_result``.
        """
        def _get_predict(r, key):
            # A field may arrive as a plain value or wrapped as {'text': ...}.
            if isinstance(r[key], dict):
                return r[key]['text']
            return r[key]

        if image.rotate is not None:
            image.gt_result['orientation'] = image.rotate + 1
        r = send_request(image, self.ocr_name, self.ocr_address, self.image_type)
        err_str = ''
        if r['status'] == '000':
            res = r['result']
            for key in self.field:
                if key not in res:
                    continue
                gt = image.get_gt_result(key)
                predict = _get_predict(res, key)
                if predict == gt:
                    self.correct[key] += 1
                else:
                    image.category = False
                    self.error[key] += 1
                    err_str += f'正确:{gt}\n返回:{predict}\n'
            # A fully correct image reports its ground truth, a failed one the
            # accumulated list of mismatches.
            image.ocr_result = image.gt_result if image.category else err_str
        else:
            image.ocr_result = r['msg']
            image.category = False
            for key in self.field:
                self.error[key] += 1

    def __call__(self):
        """Return a generator over the images."""
        yield from self.image_list

    def evaluate(self):
        """Compare every image against the OCR service, showing a progress bar."""
        for image in tqdm(self.image_list):
            self._evaluate_one(image)

    @property
    def accuracy(self):
        """Overall accuracy across all fields; ``0.0`` when nothing was counted."""
        hits = sum(self.correct.values())
        return self._ratio(hits, hits + sum(self.error.values()))

    @property
    def attrs_accuracy(self):
        """Per-field accuracy; a field that was never counted maps to ``0.0``."""
        return {
            k: self._ratio(self.correct[k], self.correct[k] + self.error[k])
            for k in self.field
        }
class MD(object):
    """Markdown test report built on top of ``MdUtils``.

    Attributes:
        name: Last component of the report path, used in the title.
        f: The underlying ``MdUtils`` document.
        field_table: Flat cell list of the per-field accuracy table.
        true_table: Flat cell list of the table of passed images.
        false_table: Flat cell list of the table of failed images.
    """

    # Header cells of the tables; each write_* call starts again from these so
    # that calling a method twice does not repeat the rows of the first call.
    _FIELD_HEADER = ('字段', '正确率')
    _RESULT_HEADER = ('图片', '识别结果')

    def __init__(self, file_path: Path):
        """Create the document and write the level-1 report title.

        Args:
            file_path: Report location; passed to ``MdUtils`` as ``file_name``.
        """
        self.name = file_path.name
        self.f = MdUtils(file_name=str(file_path))
        self.field_table: List = list(self._FIELD_HEADER)
        self.true_table: List = list(self._RESULT_HEADER)
        self.false_table: List = list(self._RESULT_HEADER)
        self.write_header(f'{self.name}测试报告')

    @staticmethod
    def _table_cell(value):
        """Render ``value`` as single-line cell text.

        A Markdown table row must stay on one line, so embedded newlines are
        turned into ``<br>`` and trailing ones are dropped.
        """
        return str(value).rstrip('\n').replace('\n', '<br>')

    def write_header(self, title, level=1):
        """Append a header with the given title and level."""
        self.f.new_header(level=level, title=title)

    def write_total_accuracy(self, ds: "Dataset"):
        """Append the overall accuracy as a percentage with two decimals."""
        self.f.new_paragraph("{:.2f}%".format(ds.accuracy * 100))

    def write_table_accuracy(self, ds: "Dataset", columns=2, text_align='center'):
        """Append a table with one (field, accuracy percentage) row per field.

        Args:
            ds: Evaluated dataset supplying ``attrs_accuracy``.
            columns: Number of table columns.
            text_align: Alignment passed to ``MdUtils.new_table``.
        """
        cells = list(self._FIELD_HEADER)
        for field_name, ratio in ds.attrs_accuracy.items():
            cells.extend((field_name, "{:.2f}%".format(ratio * 100)))
        self.field_table = cells
        rows = len(cells) // columns
        self.f.new_table(columns=columns, rows=rows, text=cells, text_align=text_align)

    def write_table_result(self, ds: "Dataset", columns=2, text_align='center'):
        """Append a 'True' and a 'False' table listing each image with its result.

        Images are split by ``image.category``; every row holds an inline
        image reference and the image's ``ocr_result`` rendered as text.

        Args:
            ds: Evaluated dataset supplying ``image_list``.
            columns: Number of table columns.
            text_align: Alignment passed to ``MdUtils.new_table``.
        """
        passed = list(self._RESULT_HEADER)
        failed = list(self._RESULT_HEADER)
        for image in ds.image_list:
            md_image = self.f.new_inline_image(text='', path=str(image.path))
            target = passed if image.category else failed
            target.extend([md_image, self._table_cell(image.ocr_result)])
        self.true_table = passed
        self.false_table = failed
        self.write_header('True')
        self.f.new_table(columns=columns, rows=len(passed) // columns, text=passed, text_align=text_align)
        self.write_header('False')
        self.f.new_table(columns=columns, rows=len(failed) // columns, text=failed, text_align=text_align)
# if __name__ == '__main__':
# markdown = MD('英语等级证书')
#
# dataset = Dataset(Path(''), 'cet', 'local', False)
# print(len(dataset))
# for d in dataset():
# print(d)
#
# dataset.evaluate()
# print(dataset.accuracy)
#
# markdown.write_total_accuracy(dataset)
# markdown.write_table_accuracy(dataset)
# markdown.write_table_result(dataset)
#
# markdown.f.create_md_file()