from pathlib import Path
from typing import List, Optional
import cv2
import requests
from dataclasses import dataclass
import json
import time
import base64
from itertools import chain
from tqdm import tqdm
@dataclass
class RequestConfig:
    """Connection settings for one OCR service endpoint."""

    # Full URL that requests are POSTed to.
    url: str
    # Value sent in the ``Authorization`` header; may be an empty string.
    token: str
# Registry of the known service endpoints, keyed by environment name.
# NOTE(review): the API tokens below are hard-coded in source; consider
# loading them from the environment or a secrets store instead.
CONFIGS = {
    'local': RequestConfig(
        url='http://192.168.199.249:18050/ocr_system/cet',
        token='',
    ),
    'test': RequestConfig(
        url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
        token='9679c2b3-b90b-4029-a3c7-f347b4d242f7',
    ),
    'sb': RequestConfig(
        url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
        token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636',
    ),
}

# Module-level aliases for the individual registry entries.
local_config = CONFIGS['local']
test_config = CONFIGS['test']
sb_config = CONFIGS['sb']

# Settings used by the script entry point at the bottom of the file.
CONFIG_STR = 'local'  # key into CONFIGS
IMAGE_TYPE = None  # optional ``image_type`` request field; omitted when falsy
IMAGE_PATH = Path('images/cet6/')  # directory searched recursively for *.jpg
# Disabled draft of a Markdown report helper, kept for reference only.
# NOTE(review): `MdUtils` is not among the imports visible in this file.
# class MarkdownTable(object):
#     def __init__(self, name):
#         self.name = name
#         self.mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + name)
#         self.field_table = ['字段', '正确率']
#         self.true_table = ['图片', '识别结果']
#         self.false_table = ['图片', '识别结果']
#
#     def add_field_table(self, fields: List):
#         self.field_table.extend(fields)
#
#     def add_true_table(self, image_and_field: List):
#         self.true_table.extend(image_and_field)
#
#     def add_false_table(self, image_and_field: List):
#         self.false_table.extend(image_and_field)
class Image:
    """One evaluation sample: an image file plus its ground-truth JSON sidecar.

    The ground truth is read from ``<stem>.json`` in the same directory as the
    image when the instance is created.

    Attributes:
        rotate: rotation code stored for this sample (``None`` = none).
            NOTE(review): this value is only stored and shown in ``repr``;
            ``get_base64`` rotates only when a code is passed to it explicitly.
        cate: ``True`` on creation; evaluation code sets it to ``False`` when
            the sample is judged wrong.
        gt_result: parsed content of the JSON sidecar.
    """

    def __init__(self, path: Path, rotate):
        """Store the sample location and load its ground truth.

        Raises:
            Exception: whatever ``get_json`` raises (missing or invalid
                sidecar); the sidecar path is printed before re-raising.
        """
        self._path = path
        self.rotate = rotate
        self._ocr_result = None
        self.cate = True
        try:
            self.gt_result = self.get_json()
        except Exception:
            # Show which sidecar failed, then propagate the original error
            # with its traceback intact.
            print(self.json_path)
            raise

    def __repr__(self):
        return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.cate}'

    @property
    def path(self):
        """Current location of the image file (changed by ``save_image``)."""
        return self._path

    @path.setter
    def path(self, path):
        self._path = path

    @property
    def fn(self):
        """File name of the image without its extension."""
        return self._path.stem

    @property
    def ocr_result(self):
        """Result recorded for this sample by the evaluation code, or ``None``."""
        return self._ocr_result

    @ocr_result.setter
    def ocr_result(self, value):
        self._ocr_result = value

    def get_gt_result(self, key):
        """Return the ground-truth value for ``key``, or ``None`` if absent."""
        return self.gt_result[key] if key in self.gt_result else None

    @property
    def json_path(self):
        """Path of the JSON sidecar: same directory and stem as the image."""
        return self.path.parent / f'{self.path.stem}.json'

    def save_image(self, img, rotate=None):
        """Write ``img`` into a ``.ro_dire`` sub-directory and repoint ``path`` at it.

        Args:
            img: image array to write.
            rotate: rotation code used to build the file name suffix
                (``-<rotate + 1>``); with ``None`` no suffix is added.
        """
        out_dir = self.path.parent / '.ro_dire'
        out_dir.mkdir(exist_ok=True)
        # The original computed ``rotate + 1`` unconditionally and crashed on
        # None; a missing code now simply yields an un-suffixed name.
        suffix = '' if rotate is None else f'-{rotate + 1}'
        self.path = out_dir / f'{self.path.stem}{suffix}.jpg'
        # Pass a plain string, consistent with the imread call in get_base64.
        cv2.imwrite(str(self.path), img)

    def get_base64(self, rotate=None):
        """Return the image as a base64-encoded JPEG string.

        Args:
            rotate: optional code forwarded to ``cv2.rotate``. When given, the
                rotated image is also written to disk via ``save_image`` and
                ``path`` then points at that copy.

        Raises:
            ValueError: if the file at ``path`` cannot be read as an image.
        """
        print(self.path)
        img = cv2.imread(str(self.path))
        if img is None:
            # imread signals failure by returning None rather than raising.
            raise ValueError(f'cannot read image: {self.path}')
        # NOTE(review): imencode/imwrite treat arrays as BGR, so converting to
        # RGB first swaps the colour channels in the encoded JPEG. Kept as in
        # the original -- confirm this is what the service expects.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if rotate is not None:
            img = cv2.rotate(img, rotate)
            # Forward the code so the saved file gets its suffix.
            self.save_image(img, rotate)
        _, encoded = cv2.imencode('.jpg', img)
        return base64.b64encode(encoded).decode('utf-8')

    def get_json(self):
        """Load and return the parsed JSON sidecar (read as UTF-8)."""
        with open(self.json_path, 'r', encoding='utf-8') as f:
            return json.load(f)
def send_request(base64_str, config_str, image_type=None, timeout=60):
    """POST a base64-encoded image to the selected endpoint and return its JSON reply.

    Args:
        base64_str: image content, sent as the ``image`` field of the JSON body.
        config_str: key into ``CONFIGS`` selecting the URL and token to use.
        image_type: optional value sent as ``image_type``; omitted when falsy.
        timeout: seconds passed to ``requests.post``. Without a timeout a
            stalled connection would block the caller indefinitely; pass
            ``None`` to restore the unbounded wait.

    Returns:
        The response body decoded from JSON.

    Raises:
        KeyError: if ``config_str`` is not a key of ``CONFIGS``.
        requests.exceptions.RequestException: on connection failure or timeout.
        ValueError: if the response body is not valid JSON.
    """
    config = CONFIGS[config_str]
    headers = {
        'Content-Type': 'application/json',
        'Authorization': config.token,
    }
    payload = {'image': base64_str}
    if image_type:
        payload['image_type'] = image_type
    response = requests.post(config.url, headers=headers, json=payload, timeout=timeout)
    return response.json()
class Dataset(object):
    """Collection of ``Image`` samples evaluated against the OCR service.

    Every ``*.jpg`` found recursively under ``image_path`` becomes one sample
    (four when ``rotate`` is true). ``evaluate`` sends each sample to the
    service and tallies per-attribute hits and misses in ``correct`` / ``error``.
    """

    def __init__(self, image_path, image_type, config_str, rotate=False):
        """Discover the samples and reset the counters.

        Args:
            image_path: directory searched recursively for ``*.jpg`` files.
            image_type: forwarded to ``send_request`` for every sample.
            config_str: key selecting the endpoint used by ``send_request``.
            rotate: when true, each file is added once per entry of
                ``[None, 0, 1, 2]`` -- presumably ``cv2.rotate`` codes (confirm).
                NOTE(review): ``_evaluate_one`` does not pass this code on, so
                the variants are currently sent unrotated.
        """
        self.image_type = image_type
        self.config_str = config_str
        self.image_path = image_path
        # A separate name avoids overwriting the ``rotate`` flag inside the loop.
        rotations = [None, 0, 1, 2] if rotate else [None]
        self.image_list = []
        for p in Path(self.image_path).rglob('*.jpg'):
            for rotation in rotations:
                self.image_list.append(Image(p, rotation))
        self.attrs = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']
        self.correct = {k: 0 for k in self.attrs}
        self.error = {k: 0 for k in self.attrs}

    def __len__(self):
        return len(self.image_list)

    @staticmethod
    def _get_predict(r, key):
        """Return the predicted value for ``key``: the ``text`` entry if it is a dict."""
        value = r[key]
        return value['text'] if isinstance(value, dict) else value

    def _evaluate_one(self, image: 'Image'):
        """Send one sample to the service and update the counters and the sample.

        A reply whose ``status`` is not ``'000'`` counts as an error for every
        attribute. Otherwise each attribute present in the result is compared
        with the ground truth; attributes missing from the result are counted
        neither as correct nor as error.
        """
        base64_str = image.get_base64()
        r = send_request(base64_str, self.config_str, self.image_type)
        print(r)

        # ``.get`` keeps one unexpected reply shape from aborting the whole run.
        if r.get('status') != '000':
            # Keep the whole reply when there is no ``msg`` to record.
            image.ocr_result = r.get('msg', r)
            image.cate = False
            for key in self.attrs:
                self.error[key] += 1
            return

        res = r['result']
        mismatches = []
        for key in self.attrs:
            if key not in res:
                continue
            gt = image.get_gt_result(key)
            predict = self._get_predict(res, key)
            print(f'gt: {gt}, predict: {predict}')
            if predict == gt:
                self.correct[key] += 1
            else:
                image.cate = False
                self.error[key] += 1
                mismatches.append(f'正确:{gt}\n返回:{predict}\n')
        image.ocr_result = r['result'] if image.cate else ''.join(mismatches)

    def evaluate(self):
        """Evaluate every sample in order, showing a progress bar."""
        for image in tqdm(self.image_list):
            self._evaluate_one(image)

    def accuracy(self):
        """Return the overall fraction of correct attribute comparisons.

        Prints the per-attribute hit counts. Returns ``0.0`` when nothing has
        been counted yet instead of dividing by zero.
        """
        print(self.correct.values())
        hits = sum(self.correct.values())
        total = hits + sum(self.error.values())
        return hits / total if total else 0.0

    def attrs_accuracy(self):
        """Return ``{attribute: accuracy}``; attributes never counted map to ``0.0``."""
        result = {}
        for k in self.attrs:
            total = self.correct[k] + self.error[k]
            result[k] = self.correct[k] / total if total else 0.0
        return result
if __name__ == '__main__':
    # Build the dataset from the module-level settings, without rotated variants.
    dataset = Dataset(IMAGE_PATH, IMAGE_TYPE, CONFIG_STR, rotate=False)

    # Show how many samples were found, then each sample with its ground truth.
    print(len(dataset))
    for sample in dataset.image_list:
        print(sample)

    # Query the service for every sample and report the overall accuracy.
    dataset.evaluate()
    overall = dataset.accuracy()
    print(overall)