|
@@ -1,222 +0,0 @@
|
|
|
-from pathlib import Path
|
|
|
-from typing import List, Optional
|
|
|
-import cv2
|
|
|
-import requests
|
|
|
-from dataclasses import dataclass
|
|
|
-import json
|
|
|
-import time
|
|
|
-import base64
|
|
|
-from itertools import chain
|
|
|
-from tqdm import tqdm
|
|
|
-
|
|
|
-
|
|
|
@dataclass
class RequestConfig:
    """Connection settings for one OCR endpoint."""

    # Full URL the JSON payload is POSTed to.
    url: str
    # Sent verbatim as the ``Authorization`` header; may be an empty string.
    token: str
|
|
|
-
|
|
|
-
|
|
|
# NOTE(review): the tokens below are committed in plain text. If they are real
# credentials they should be rotated and loaded from the environment instead.
local_config = RequestConfig(
    url='http://192.168.199.249:18050/ocr_system/cet',
    token='',
)
test_config = RequestConfig(
    url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
    token='9679c2b3-b90b-4029-a3c7-f347b4d242f7',
)
sb_config = RequestConfig(
    url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
    token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636',
)

# Lookup table from a short environment name to its endpoint settings.
CONFIGS = dict(local=local_config, test=test_config, sb=sb_config)

# Key into CONFIGS used when the module is run as a script.
CONFIG_STR = 'local'

# Optional ``image_type`` field for the request payload; None leaves it out.
IMAGE_TYPE = None

# Directory that is searched recursively for ``*.jpg`` samples.
IMAGE_PATH = Path('images/cet6/')
|
|
|
-
|
|
|
-
|
|
|
-# class MarkdownTable(object):
|
|
|
-# def __init__(self, name):
|
|
|
-# self.name = name
|
|
|
-# self.mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + name)
|
|
|
-# self.field_table = ['字段', '正确率']
|
|
|
-# self.true_table = ['图片', '识别结果']
|
|
|
-# self.false_table = ['图片', '识别结果']
|
|
|
-
|
|
|
-# def add_field_table(self, fields: List):
|
|
|
-# self.field_table.extend(fields)
|
|
|
-
|
|
|
-# def add_true_table(self, image_and_field: List):
|
|
|
-# self.true_table.extend(image_and_field)
|
|
|
-
|
|
|
-# def add_false_table(self, image_and_field: List):
|
|
|
-# self.false_table.extend(image_and_field)
|
|
|
-
|
|
|
-
|
|
|
class Image:
    """One evaluation sample: an image file, an optional rotation and its ground truth.

    The ground-truth JSON is looked up next to the image and shares its stem
    (``foo.jpg`` -> ``foo.json``). It is loaded eagerly by the constructor.
    """

    def __init__(self, path: Path, rotate):
        """Bind the sample to ``path`` and load its ground-truth JSON.

        Args:
            path: Location of the image file.
            rotate: Rotation code later handed to ``cv2.rotate`` (the dataset
                in this file passes 0, 1 or 2), or ``None`` for no rotation.

        Raises:
            Exception: Whatever reading or parsing the JSON raises; the JSON
                path is printed before the exception propagates.
        """
        # The file the sample was created from. Pixels and ground truth are
        # always resolved against this path, even after ``path`` is rebound to
        # a rotated copy by ``save_image``.
        self._source_path = path
        self._path = path
        self.rotate = rotate
        self._ocr_result = None
        # Stays True until an evaluator records a mismatch or a failed request.
        self.cate = True
        try:
            self.gt_result = self.get_json()
        except Exception:
            print(self.json_path)
            raise

    def __repr__(self):
        return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.cate}'

    @property
    def path(self):
        """Path reported for this sample (the rotated copy once one was saved)."""
        return self._path

    @path.setter
    def path(self, path):
        # Only the reported path changes; reading still uses the source file.
        self._path = path

    @property
    def fn(self):
        """Stem (file name without suffix) of the reported path."""
        return self._path.stem

    @property
    def ocr_result(self):
        """Result attached by the evaluator, or ``None`` before evaluation."""
        return self._ocr_result

    @ocr_result.setter
    def ocr_result(self, value):
        self._ocr_result = value

    def get_gt_result(self, key):
        """Return the expected value for ``key``.

        ``'orientation'`` is derived from the rotation: ``rotate + 1``, or 0
        when the sample is not rotated. Any other key is looked up in the
        ground-truth JSON; ``None`` is returned when it is absent.
        """
        if key == 'orientation':
            return 0 if self.rotate is None else self.rotate + 1
        if key in self.gt_result:
            return self.gt_result[key]
        return None

    @property
    def json_path(self):
        """Path of the ground-truth JSON that sits beside the source image."""
        return self._source_path.parent / f'{self._source_path.stem}.json'

    def save_image(self, img, rotate):
        """Write the rotated RGB array ``img`` to ``.ro_dst`` and point ``path`` at it.

        The ``.ro_dst`` directory is created two levels above the source
        image; the file is named ``<stem>-<rotate+1>.jpg``.
        """
        dst = self._source_path.parent.parent / '.ro_dst'
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        dst.mkdir(parents=True, exist_ok=True)
        self.path = dst / f'{self._source_path.stem}-{rotate+1}.jpg'
        print('save image', self.path)
        # imwrite expects BGR channel order, so undo the earlier conversion.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(self.path), img)

    def get_base64(self, rotate=None):
        """Return the image as a base64 string of JPEG bytes.

        Args:
            rotate: Rotation code for ``cv2.rotate``; when given, the rotated
                image is also written to disk through ``save_image``.

        Raises:
            OSError: If the source image cannot be read.
            ValueError: If JPEG encoding fails.
        """
        print(self._source_path)
        # Always start from the source file so that repeated calls do not
        # rotate an already rotated copy.
        img = cv2.imread(str(self._source_path))
        if img is None:
            raise OSError(f'cannot read image: {self._source_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if rotate is not None:
            img = cv2.rotate(img, rotate)
            self.save_image(img, rotate)
        # NOTE(review): imencode is given the RGB array, so the encoded JPEG
        # has red and blue swapped relative to the file on disk. Kept as in
        # the original; confirm whether the service expects this.
        ok, encoded = cv2.imencode('.jpg', img)
        if not ok:
            raise ValueError(f'cannot encode image: {self._source_path}')
        return base64.b64encode(encoded).decode('utf-8')

    def get_json(self):
        """Load and return the parsed ground-truth JSON."""
        # Explicit UTF-8: the platform default encoding cannot be relied on
        # for non-ASCII ground-truth values.
        with open(self.json_path, 'r', encoding='utf-8') as f:
            return json.load(f)
|
|
|
-
|
|
|
-
|
|
|
def send_request(image: "Image", config_str, image_type=None, timeout=60):
    """POST one image to the configured OCR endpoint and return the parsed JSON reply.

    Args:
        image: Sample to send; its ``rotate`` value is applied while encoding.
        config_str: Key into ``CONFIGS`` that selects the endpoint and token.
        image_type: Optional value for the ``image_type`` payload field; the
            field is left out when this is falsy.
        timeout: Seconds to wait for the server, passed to ``requests.post``.
            Without it a stalled endpoint would block the run indefinitely.

    Returns:
        The decoded JSON body of the response.

    Raises:
        KeyError: If ``config_str`` is not a key of ``CONFIGS``.
        requests.exceptions.RequestException: On connection errors or timeout.
    """
    # Resolve the endpoint first so a bad key fails before any image work.
    config = CONFIGS[config_str]
    payload = {'image': image.get_base64(image.rotate)}
    if image_type:
        payload['image_type'] = image_type
    headers = {
        'Content-Type': 'application/json',
        'Authorization': config.token,
    }
    response = requests.post(config.url, headers=headers, json=payload, timeout=timeout)
    return response.json()
|
|
|
-
|
|
|
-
|
|
|
class Dataset(object):
    """All ``*.jpg`` samples under a directory, with per-field accuracy counters."""

    def __init__(self, image_path, image_type, config_str, rotate=False):
        """Collect the samples found recursively under ``image_path``.

        Args:
            image_path: Directory searched recursively for ``*.jpg`` files.
            image_type: Forwarded to ``send_request`` for every sample.
            config_str: Key into ``CONFIGS`` forwarded to ``send_request``.
            rotate: When true, every file yields four samples: unrotated plus
                rotation codes 0, 1 and 2. Otherwise one unrotated sample.
        """
        self.image_type = image_type
        self.config_str = config_str
        self.image_path = image_path
        rotations = [None, 0, 1, 2] if rotate else [None]
        self.image_list = []
        # Sorted so that the evaluation order does not depend on the filesystem.
        for p in sorted(Path(self.image_path).rglob('*.jpg')):
            for r in rotations:
                self.image_list.append(Image(p, r))

        self.attrs = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']

        self.correct = {k: 0 for k in self.attrs}
        self.error = {k: 0 for k in self.attrs}

    def __len__(self):
        return len(self.image_list)

    @staticmethod
    def _get_predict(r, key):
        """Return the predicted value for ``key``, unwrapping ``{'text': ...}`` dicts."""
        value = r[key]
        return value['text'] if isinstance(value, dict) else value

    @staticmethod
    def _ratio(hits, misses):
        """Return hits / (hits + misses), or 0.0 when nothing was counted."""
        total = hits + misses
        return hits / total if total else 0.0

    def _evaluate_one(self, image: "Image"):
        """Send one sample to the service and update the counters.

        A reply whose ``status`` is not ``'000'`` counts as an error for every
        attribute. Otherwise each attribute present in ``result`` is compared
        with the ground truth; attributes missing from ``result`` are counted
        neither as correct nor as errors.
        """
        r = send_request(image, self.config_str, self.image_type)
        if r.get('status') != '000':
            # Failed request: keep the server message and fail every field.
            image.ocr_result = r.get('msg')
            image.cate = False
            for key in self.attrs:
                self.error[key] += 1
            return

        res = r['result']
        err_str = ''
        for key in self.attrs:
            print('attr: ', key)
            if key not in res:
                continue
            gt = image.get_gt_result(key)
            predict = self._get_predict(res, key)
            print(f'gt: {gt}, predict: {predict}')
            if predict == gt:
                self.correct[key] += 1
            else:
                image.cate = False
                self.error[key] += 1
                err_str += f'正确:{gt}<br>返回:{predict}<br>'
        # A fully correct sample keeps the raw result, a wrong one the diff text.
        image.ocr_result = r['result'] if image.cate else err_str

    def __call__(self):
        """Yield every sample in order."""
        for image in self.image_list:
            yield image

    def evaluate(self):
        """Evaluate every sample, showing a progress bar."""
        for image in tqdm(self.image_list):
            self._evaluate_one(image)

    def accuracy(self):
        """Return the share of correct field comparisons, 0.0 if none were made."""
        return self._ratio(sum(self.correct.values()), sum(self.error.values()))

    def attrs_accuracy(self):
        """Return the accuracy per attribute, 0.0 for attributes never compared."""
        return {k: self._ratio(self.correct[k], self.error[k]) for k in self.attrs}
|
|
|
-
|
|
|
-
|
|
|
if __name__ == '__main__':
    # Build the dataset with all rotation variants, list the samples, then
    # run the evaluation and report the overall accuracy.
    dataset = Dataset(IMAGE_PATH, IMAGE_TYPE, CONFIG_STR, rotate=True)
    print(len(dataset))
    for sample in dataset():
        print(sample)

    dataset.evaluate()
    print(dataset.accuracy())
|