from pathlib import Path
from typing import List, Optional
import cv2
import requests
from dataclasses import dataclass
import json
import time
import base64
from itertools import chain
from tqdm import tqdm


@dataclass
class RequestConfig:
    """Endpoint URL and auth token for one OCR service deployment."""
    url: str
    token: str


# NOTE(review): API tokens are committed to source control here; consider
# loading them from environment variables or a secrets store instead.
local_config = RequestConfig(url='http://192.168.199.249:18050/ocr_system/cet', token='')
test_config = RequestConfig(
    url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
    token='9679c2b3-b90b-4029-a3c7-f347b4d242f7',
)
sb_config = RequestConfig(
    url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
    token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636',
)

# Deployment name -> request configuration, selected by CONFIG_STR.
CONFIGS = {
    'local': local_config,
    'test': test_config,
    'sb': sb_config,
}

CONFIG_STR = 'local'
IMAGE_TYPE = None
IMAGE_PATH = Path('images/cet6/')

# Sub-directory (created next to each source image) holding rotated copies.
ROTATED_DIR_NAME = '.ro_dire'
# Seconds to wait for the OCR service before a request is abandoned.
REQUEST_TIMEOUT = 60

# class MarkdownTable(object):
#     def __init__(self, name):
#         self.name = name
#         self.mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + name)
#         self.field_table = ['字段', '正确率']
#         self.true_table = ['图片', '识别结果']
#         self.false_table = ['图片', '识别结果']
#     def add_field_table(self, fields: List):
#         self.field_table.extend(fields)
#     def add_true_table(self, image_and_field: List):
#         self.true_table.extend(image_and_field)
#     def add_false_table(self, image_and_field: List):
#         self.false_table.extend(image_and_field)


class Image:
    """One test image plus its ground truth, optionally evaluated rotated.

    The ground truth is loaded from a JSON file with the same stem as the
    image, located in the same directory.

    Args:
        path: Path to the source ``.jpg`` image.
        rotate: ``None`` for no rotation, otherwise a ``cv2.ROTATE_*`` code
            applied with ``cv2.rotate`` before the image is sent.
    """

    def __init__(self, path: Path, rotate: Optional[int]):
        # Keep the source location separately: ``path`` is repointed at the
        # rotated copy once one is saved, but reads must use the original.
        self._src_path = Path(path)
        self._path = self._src_path
        self.rotate = rotate
        self._ocr_result = None
        # Stays True while every checked field matches the ground truth.
        self.cate = True
        try:
            self.gt_result = self.get_json()
        except Exception:
            print(self.json_path)
            raise

    def __repr__(self):
        return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.cate}'

    @property
    def path(self):
        """Path of the image last sent (the rotated copy, if one was saved)."""
        return self._path

    @path.setter
    def path(self, path):
        self._path = path

    @property
    def fn(self):
        """File stem of ``path``."""
        return self._path.stem

    @property
    def ocr_result(self):
        """Raw result, mismatch report, or service error message."""
        return self._ocr_result

    @ocr_result.setter
    def ocr_result(self, value):
        self._ocr_result = value

    def get_gt_result(self, key):
        """Return the ground-truth value for ``key``, or None if absent."""
        return self.gt_result.get(key)

    @property
    def json_path(self):
        """Ground-truth JSON path, always derived from the source image."""
        return self._src_path.parent / f'{self._src_path.stem}.json'

    def save_image(self, img, rotate=None):
        """Save ``img`` into the rotated-copy directory and point ``path`` at it.

        The file is ``<stem>-<rotate + 1>.jpg`` (``<stem>.jpg`` when ``rotate``
        is None) inside ``.ro_dire`` next to the source image.
        """
        out_dir = self._src_path.parent / ROTATED_DIR_NAME
        out_dir.mkdir(parents=True, exist_ok=True)
        suffix = '' if rotate is None else f'-{rotate + 1}'
        self.path = out_dir / f'{self._src_path.stem}{suffix}.jpg'
        # cv2.imwrite takes a str filename, not a Path.
        cv2.imwrite(str(self.path), img)

    def get_base64(self, rotate=None):
        """Return the source image, optionally rotated, as a base64 JPEG string.

        When ``rotate`` is given, the rotated image is also written to disk via
        ``save_image`` so the exact payload can be inspected afterwards.

        Raises:
            FileNotFoundError: if OpenCV cannot read the source image.
            ValueError: if JPEG encoding fails.
        """
        print(self.path)
        # Always start from the original file so repeated calls never stack a
        # rotation on top of a previously saved rotated copy.
        img = cv2.imread(str(self._src_path))
        if img is None:
            raise FileNotFoundError(f'cannot read image: {self._src_path}')
        # imread returns BGR, and imencode/imwrite also expect BGR, so no colour
        # conversion is done (converting to RGB first swapped red and blue).
        if rotate is not None:
            img = cv2.rotate(img, rotate)
            self.save_image(img, rotate)
        ok, buf = cv2.imencode('.jpg', img)
        if not ok:
            raise ValueError(f'failed to JPEG-encode image: {self._src_path}')
        return base64.b64encode(buf.tobytes()).decode('utf-8')

    def get_json(self):
        """Load and return the ground-truth JSON for this image."""
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(self.json_path, 'r', encoding='utf-8') as f:
            return json.load(f)


def send_request(base64_str, config_str, image_type=None, timeout=REQUEST_TIMEOUT):
    """POST a base64 image to the OCR service named by ``config_str``.

    Args:
        base64_str: Base64-encoded JPEG payload.
        config_str: Key into ``CONFIGS`` selecting URL and token.
        image_type: Optional ``image_type`` field forwarded to the service.
        timeout: Seconds before ``requests`` gives up (prevents hanging forever).

    Returns:
        The decoded JSON response body.
    """
    config = CONFIGS[config_str]
    headers = {
        'Content-Type': 'application/json',
        'Authorization': config.token,
    }
    data = {
        'image': base64_str,
    }
    if image_type:
        data['image_type'] = image_type
    response = requests.post(config.url, headers=headers, json=data, timeout=timeout)
    return response.json()


class Dataset(object):
    """Collection of ground-truth images evaluated against the OCR service.

    Args:
        image_path: Directory searched recursively for ``*.jpg`` images.
        image_type: Optional ``image_type`` forwarded with each request.
        config_str: Key into ``CONFIGS`` selecting the service deployment.
        rotate: If True, every image is evaluated unrotated and at 90° CW,
            180° and 90° CCW.
    """

    def __init__(self, image_path, image_type, config_str, rotate=False):
        self.image_type = image_type
        self.config_str = config_str
        self.image_path = image_path
        if rotate:
            rotations = (None, cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_180,
                         cv2.ROTATE_90_COUNTERCLOCKWISE)
        else:
            rotations = (None,)
        self.image_list = []
        for p in sorted(Path(self.image_path).rglob('*.jpg')):
            # Skip rotated copies written by earlier runs: they have no
            # ground-truth JSON next to them and would crash Image().
            if ROTATED_DIR_NAME in p.parts:
                continue
            for rotation in rotations:
                self.image_list.append(Image(p, rotation))
        self.attrs = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']
        self.correct = {k: 0 for k in self.attrs}
        self.error = {k: 0 for k in self.attrs}

    def __len__(self):
        return len(self.image_list)

    def _evaluate_one(self, image: Image):
        """Send one image to the service and score each attribute.

        Updates ``self.correct`` / ``self.error`` and stores on the image either
        the raw result (all checked fields matched), a per-field mismatch
        report, or the service's error message.
        """
        def _get_predict(result, key):
            # A field is returned either as a plain value or as {'text': ...}.
            value = result[key]
            return value['text'] if isinstance(value, dict) else value

        image.cate = True
        r = send_request(image.get_base64(image.rotate), self.config_str, self.image_type)
        print(r)
        if r.get('status') != '000':
            # Service-level failure: every attribute counts as wrong.
            image.ocr_result = r.get('msg')
            image.cate = False
            for key in self.attrs:
                self.error[key] += 1
            return

        res = r['result']
        mismatches = []
        for key in self.attrs:
            # NOTE(review): fields missing from the response are counted
            # neither correct nor wrong; confirm that is intended.
            if key not in res:
                continue
            gt = image.get_gt_result(key)
            predict = _get_predict(res, key)
            print(f'gt: {gt}, predict: {predict}')
            if predict == gt:
                self.correct[key] += 1
            else:
                image.cate = False
                self.error[key] += 1
                mismatches.append(f'正确:{gt}\n返回:{predict}\n')
        image.ocr_result = res if image.cate else ''.join(mismatches)

    def evaluate(self):
        """Evaluate every image; counters are reset so reruns don't double-count."""
        self.correct = {k: 0 for k in self.attrs}
        self.error = {k: 0 for k in self.attrs}
        for image in tqdm(self.image_list):
            self._evaluate_one(image)

    def accuracy(self):
        """Overall fraction of correct fields; 0.0 if nothing was scored."""
        print(self.correct.values())
        n_correct = sum(self.correct.values())
        total = n_correct + sum(self.error.values())
        return n_correct / total if total else 0.0

    def attrs_accuracy(self):
        """Per-attribute accuracy; 0.0 for attributes with no scored samples."""
        result = {}
        for k in self.attrs:
            total = self.correct[k] + self.error[k]
            result[k] = self.correct[k] / total if total else 0.0
        return result


if __name__ == '__main__':
    dataset = Dataset(IMAGE_PATH, IMAGE_TYPE, CONFIG_STR, False)
    print(len(dataset))
    for d in dataset.image_list:
        print(d)
    dataset.evaluate()
    print(dataset.accuracy())