123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222 |
- from pathlib import Path
- from typing import List, Optional
- import cv2
- import requests
- from dataclasses import dataclass
- import json
- import time
- import base64
- from itertools import chain
- from tqdm import tqdm
@dataclass
class RequestConfig:
    """Connection settings for one OCR service endpoint.

    Attributes:
        url: Full URL the recognition request is posted to.
        token: Value sent as the ``Authorization`` header; may be empty.
    """

    url: str
    token: str
# NOTE(review): the access tokens below are committed in plain text; consider
# loading them from the environment or a secrets store instead.

# Service reachable on the local network; sends an empty token.
local_config = RequestConfig(
    url='http://192.168.199.249:18050/ocr_system/cet',
    token='',
)

# Deployment behind the "aihub-test" host.
test_config = RequestConfig(
    url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
    token='9679c2b3-b90b-4029-a3c7-f347b4d242f7',
)

# Deployment behind the "aihub" host (presumably production; what "sb" stands
# for is not documented in this file).
sb_config = RequestConfig(
    url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
    token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636',
)

# Lookup table from a short environment name to its endpoint settings.
CONFIGS = dict(local=local_config, test=test_config, sb=sb_config)

# Key into CONFIGS selecting the endpoint used when run as a script.
CONFIG_STR = 'local'

# Optional "image_type" request field; a falsy value leaves the field out.
IMAGE_TYPE = None

# Directory that is searched recursively for ``*.jpg`` samples.
IMAGE_PATH = Path('images/cet6/')
- # class MarkdownTable(object):
- # def __init__(self, name):
- # self.name = name
- # self.mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + name)
- # self.field_table = ['字段', '正确率']
- # self.true_table = ['图片', '识别结果']
- # self.false_table = ['图片', '识别结果']
- # def add_field_table(self, fields: List):
- # self.field_table.extend(fields)
- # def add_true_table(self, image_and_field: List):
- # self.true_table.extend(image_and_field)
- # def add_false_table(self, image_and_field: List):
- # self.false_table.extend(image_and_field)
class Image:
    """One evaluation sample: an image file plus its ground-truth JSON.

    The ground truth is loaded at construction time from a ``.json`` file that
    sits next to the image and shares its stem.
    """

    def __init__(self, path: Path, rotate):
        """Load the ground truth for the image at ``path``.

        Args:
            path: Location of the image file.
            rotate: ``None`` to send the image as-is, otherwise the rotation
                code handed to ``cv2.rotate`` (0, 1 or 2 — in OpenCV these are
                90° clockwise, 180° and 90° counter-clockwise).

        Raises:
            Exception: Whatever reading or parsing the JSON file raises; the
                JSON path is printed first so the bad sample can be found.
        """
        self._path = path
        # File that pixels and ground truth are read from. Unlike ``_path`` it
        # is NOT redirected when a rotated copy is written, so ``json_path``
        # stays valid and repeated ``get_base64`` calls never re-rotate.
        self._source_path = path
        self.rotate = rotate
        self._ocr_result = None
        # Stays True until a field mismatch or a failed request is recorded.
        self.cate = True
        try:
            self.gt_result = self.get_json()
        except Exception:
            print(self.json_path)
            raise  # bare raise keeps the original traceback intact

    def __repr__(self):
        return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.cate}'

    @property
    def path(self):
        """Current image file; the rotated copy once one has been saved."""
        return self._path

    @path.setter
    def path(self, path):
        # An explicit reassignment also becomes the file that is read from.
        self._path = path
        self._source_path = path

    @property
    def fn(self):
        """Stem (file name without extension) of the current image file."""
        return self._path.stem

    @property
    def ocr_result(self):
        """Outcome stored by the evaluator (service result or error text)."""
        return self._ocr_result

    @ocr_result.setter
    def ocr_result(self, value):
        self._ocr_result = value

    def get_gt_result(self, key):
        """Return the expected value for ``key``, or ``None`` if unknown.

        ``'orientation'`` is derived from the rotation instead of the JSON:
        0 for an unrotated image, otherwise ``rotate + 1``.
        """
        if key == 'orientation':
            return 0 if self.rotate is None else self.rotate + 1
        return self.gt_result.get(key)

    @property
    def json_path(self):
        """Ground-truth file: same directory and stem as the source image."""
        return self._source_path.with_suffix('.json')

    def save_image(self, img, rotate):
        """Write a rotated RGB image to disk and make ``path`` point at it.

        The file goes to ``<source dir>/../.ro_dst/<stem>-<rotate+1>.jpg``.

        Args:
            img: RGB pixel array; converted back to BGR before writing.
            rotate: Rotation code that produced ``img``.
        """
        source = self._source_path
        target_dir = source.parent.parent / '.ro_dst'
        # Race-free, and also creates missing parents.
        target_dir.mkdir(parents=True, exist_ok=True)
        # Assign ``_path`` directly: going through the setter would also move
        # the read source (and with it ``json_path``) onto the rotated copy.
        self._path = target_dir / f'{source.stem}-{rotate+1}.jpg'
        print('save image', self.path)
        cv2.imwrite(str(self.path), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

    def get_base64(self, rotate=None):
        """Return the image as a base64-encoded JPEG string.

        Args:
            rotate: ``None`` to encode the image unchanged; otherwise a
                ``cv2.rotate`` code. The rotated image is also saved to disk
                via ``save_image``.

        Raises:
            OSError: If the source image cannot be read.
            ValueError: If JPEG encoding fails.
        """
        print(self._source_path)
        pixels = cv2.imread(str(self._source_path))
        if pixels is None:
            # imread signals failure with None instead of raising.
            raise OSError(f'cannot read image: {self._source_path}')
        pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
        if rotate is not None:
            pixels = cv2.rotate(pixels, rotate)
            self.save_image(pixels, rotate)
        # NOTE(review): the array is RGB here while OpenCV encoders assume
        # BGR, so the JPEG that is sent appears to have red and blue swapped.
        # Left unchanged because the service may rely on it — confirm.
        encoded_ok, jpeg = cv2.imencode('.jpg', pixels)
        if not encoded_ok:
            raise ValueError(f'cannot encode image as JPEG: {self._source_path}')
        return base64.b64encode(jpeg).decode('utf-8')

    def get_json(self):
        """Parse and return the ground-truth JSON file."""
        # Explicit UTF-8: the platform default encoding cannot be relied on
        # for non-ASCII field values.
        with open(self.json_path, 'r', encoding='utf-8') as f:
            return json.load(f)
def send_request(image: "Image", config_str, image_type=None, timeout=60):
    """Post one image to the configured OCR endpoint and return its reply.

    Args:
        image: Sample to send; encoded with ``image.get_base64(image.rotate)``.
        config_str: Key into ``CONFIGS`` selecting the endpoint and token.
        image_type: Optional value for the ``image_type`` request field; the
            field is left out when this is falsy.
        timeout: Seconds to wait for the service before giving up. Pass
            ``None`` to wait indefinitely (the previous behaviour).

    Returns:
        The response body decoded from JSON.

    Raises:
        KeyError: If ``config_str`` is not a key of ``CONFIGS``.
        requests.exceptions.Timeout: If the service does not answer in time.
    """
    # Resolve the endpoint first so a bad key fails before any image work.
    config = CONFIGS[config_str]

    payload = {'image': image.get_base64(image.rotate)}
    if image_type:
        payload['image_type'] = image_type

    request_headers = {
        'Content-Type': 'application/json',
        'Authorization': config.token,
    }
    # Without a timeout a single stalled connection blocks the whole run.
    response = requests.post(
        config.url,
        headers=request_headers,
        json=payload,
        timeout=timeout,
    )
    return response.json()
class Dataset:
    """Collection of image samples and per-field accuracy counters.

    Every ``*.jpg`` found recursively under ``image_path`` becomes one sample,
    or four samples (unrotated plus rotation codes 0, 1, 2) when ``rotate`` is
    true. ``evaluate`` sends each sample to the OCR service and compares the
    reply with the ground truth, field by field.
    """

    def __init__(self, image_path, image_type, config_str, rotate=False):
        """Collect the samples and reset the counters.

        Args:
            image_path: Directory searched recursively for ``*.jpg`` files.
            image_type: Forwarded to ``send_request`` for every sample.
            config_str: Key into ``CONFIGS`` naming the endpoint to evaluate.
            rotate: When true, also create rotated variants of every image.
        """
        self.image_type = image_type
        self.config_str = config_str
        self.image_path = image_path

        rotations = (None, 0, 1, 2) if rotate else (None,)
        self.image_list = []
        for jpg in Path(self.image_path).rglob('*.jpg'):
            for rotation in rotations:
                self.image_list.append(Image(jpg, rotation))

        # Fields that are scored, plus one hit and one miss counter for each.
        self.attrs = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']
        self.correct = {k: 0 for k in self.attrs}
        self.error = {k: 0 for k in self.attrs}

    def __len__(self):
        return len(self.image_list)

    @staticmethod
    def _ratio(hits, total):
        """Return ``hits / total``, or 0.0 when nothing has been counted."""
        return hits / total if total else 0.0

    def _evaluate_one(self, image: "Image"):
        """Send one sample to the service and update counters and the sample.

        A reply whose ``status`` is ``'000'`` is scored field by field; fields
        absent from the reply are not counted either way. Any other status
        counts as a miss for every field.
        """
        def _predicted_value(result, key):
            # A field is either a plain value or a dict carrying it as 'text'.
            value = result[key]
            return value['text'] if isinstance(value, dict) else value

        r = send_request(image, self.config_str, self.image_type)

        if r['status'] != '000':
            image.ocr_result = r['msg']
            image.cate = False
            for key in self.attrs:
                self.error[key] += 1
            return

        res = r['result']
        mismatches = []
        for key in self.attrs:
            print('attr: ', key)
            if key not in res:
                continue
            gt = image.get_gt_result(key)
            predict = _predicted_value(res, key)
            print(f'gt: {gt}, predict: {predict}')
            if predict == gt:
                self.correct[key] += 1
            else:
                image.cate = False
                self.error[key] += 1
                mismatches.append(f'正确:{gt}<br>返回:{predict}<br>')

        # Correct samples keep the raw result; wrong ones keep the mismatches.
        image.ocr_result = r['result'] if image.cate else ''.join(mismatches)

    def __call__(self):
        """Iterate over the samples in collection order."""
        yield from self.image_list

    def evaluate(self):
        """Evaluate every sample, showing a progress bar."""
        for image in tqdm(self.image_list):
            self._evaluate_one(image)

    def accuracy(self):
        """Return the share of correct fields over all counted fields.

        Returns 0.0 instead of raising ``ZeroDivisionError`` when nothing has
        been counted yet (empty dataset, or ``evaluate`` not run).
        """
        hits = sum(self.correct.values())
        return self._ratio(hits, hits + sum(self.error.values()))

    def attrs_accuracy(self):
        """Return the accuracy per field; 0.0 for fields never counted."""
        return {
            k: self._ratio(self.correct[k], self.correct[k] + self.error[k])
            for k in self.attrs
        }
if __name__ == '__main__':
    # Build the sample set (with rotated variants) from the module settings.
    evaluation_set = Dataset(IMAGE_PATH, IMAGE_TYPE, CONFIG_STR, True)
    print(len(evaluation_set))

    # List every sample before any request is sent.
    for sample in evaluation_set():
        print(sample)

    # Query the service for each sample, then report the overall accuracy.
    evaluation_set.evaluate()
    print(evaluation_set.accuracy())
|