|
@@ -1,9 +1,5 @@
|
|
|
-import operator
|
|
|
from pathlib import Path
|
|
|
-from typing import List
|
|
|
-
|
|
|
-import numpy as np
|
|
|
-from mdutils.mdutils import MdUtils
|
|
|
+from typing import List, Optional
|
|
|
import cv2
|
|
|
import requests
|
|
|
from dataclasses import dataclass
|
|
@@ -11,148 +7,211 @@ import json
|
|
|
import time
|
|
|
import base64
|
|
|
from itertools import chain
|
|
|
+from tqdm import tqdm
|
|
|
|
|
|
+@dataclass
|
|
|
+class RequestConfig:
|
|
|
+ url: str
|
|
|
+ token: str
|
|
|
|
|
|
-class MarkdownTable(object):
|
|
|
- def __init__(self, name):
|
|
|
- self.name = name
|
|
|
- self.mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + name)
|
|
|
- self.field_table = ['字段', '正确率']
|
|
|
- self.true_table = ['图片', '识别结果']
|
|
|
- self.false_table = ['图片', '识别结果']
|
|
|
|
|
|
- def add_field_table(self, fields: List):
|
|
|
- self.field_table.extend(fields)
|
|
|
|
|
|
- def add_true_table(self, image_and_field: List):
|
|
|
- self.true_table.extend(image_and_field)
|
|
|
+local_config = RequestConfig(url='http://192.168.199.249:18050/ocr_system/cet', token='')
|
|
|
+test_config = RequestConfig(url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet', token='9679c2b3-b90b-4029-a3c7-f347b4d242f7')
|
|
|
+sb_config = RequestConfig(url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet', token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636')
|
|
|
|
|
|
- def add_false_table(self, image_and_field: List):
|
|
|
- self.false_table.extend(image_and_field)
|
|
|
+CONFIGS = {
|
|
|
+ 'local': local_config,
|
|
|
+ 'test': test_config,
|
|
|
+ 'sb': sb_config
|
|
|
+}
|
|
|
|
|
|
+CONFIG_STR = 'local'
|
|
|
|
|
|
-@dataclass
|
|
|
-class Image:
|
|
|
- path: Path
|
|
|
- rotate: int
|
|
|
+IMAGE_TYPE = None
|
|
|
|
|
|
+IMAGE_PATH = Path('images/cet6/')
|
|
|
|
|
|
- @property
|
|
|
- def fn(self):
|
|
|
- return self.path.stem
|
|
|
|
|
|
- @property
|
|
|
- def json_path(self):
|
|
|
- return self.path.parent / f'{self.path.stem}.json'
|
|
|
+# class MarkdownTable(object):
|
|
|
+# def __init__(self, name):
|
|
|
+# self.name = name
|
|
|
+# self.mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + name)
|
|
|
+# self.field_table = ['字段', '正确率']
|
|
|
+# self.true_table = ['图片', '识别结果']
|
|
|
+# self.false_table = ['图片', '识别结果']
|
|
|
|
|
|
- def get_base64(self, rotate=0):
|
|
|
- return 'dsf'
|
|
|
+# def add_field_table(self, fields: List):
|
|
|
+# self.field_table.extend(fields)
|
|
|
|
|
|
+# def add_true_table(self, image_and_field: List):
|
|
|
+# self.true_table.extend(image_and_field)
|
|
|
|
|
|
+# def add_false_table(self, image_and_field: List):
|
|
|
+# self.false_table.extend(image_and_field)
|
|
|
+
|
|
|
+
|
|
|
+class Image:
|
|
|
+ def __init__(self, path: Path, rotate):
|
|
|
+ self._path = path
|
|
|
+ self.rotate = rotate
|
|
|
+ self._ocr_result = None
|
|
|
+ self.cate = True
|
|
|
+ try:
|
|
|
+ self.gt_result = self.get_json()
|
|
|
+ except Exception as e:
|
|
|
+ print(self.json_path)
|
|
|
+ raise e
|
|
|
+
|
|
|
+ def __repr__(self):
|
|
|
+ return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.cate}'
|
|
|
+
|
|
|
+ @property
|
|
|
+ def path(self):
|
|
|
+ return self._path
|
|
|
+
|
|
|
+ @path.setter
|
|
|
+ def path(self, path):
|
|
|
+ self._path = path
|
|
|
+
|
|
|
+ @property
|
|
|
+ def fn(self):
|
|
|
+ return self._path.stem
|
|
|
|
|
|
+ @property
|
|
|
+ def ocr_result(self):
|
|
|
+ return self._ocr_result
|
|
|
|
|
|
+ @ocr_result.setter
|
|
|
+ def ocr_result(self, value):
|
|
|
+ self._ocr_result = value
|
|
|
|
|
|
+ def get_gt_result(self, key):
|
|
|
+ if key in self.gt_result:
|
|
|
+ return self.gt_result[key]
|
|
|
+ else:
|
|
|
+ return None
|
|
|
|
|
|
+ @property
|
|
|
+ def json_path(self):
|
|
|
+ return self.path.parent / f'{self.path.stem}.json'
|
|
|
|
|
|
-class DataSet(object):
|
|
|
- def __init__(self, image_path, rotate=False):
|
|
|
+ def save_image(self, img, rotate=None):
|
|
|
+ dire = self.path.parent / (".ro_dire")
|
|
|
+ if not dire.exists(): dire.mkdir()
|
|
|
+ self.path = dire / f'{self.path.stem}-{rotate+1}.jpg'
|
|
|
+ cv2.imwrite(self.path, img)
|
|
|
+
|
|
|
+ def get_base64(self, rotate=None):
|
|
|
+ print(self.path)
|
|
|
+ img = cv2.imread(str(self.path))
|
|
|
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
+ if rotate is not None:
|
|
|
+ img = cv2.rotate(img, rotate)
|
|
|
+ self.save_image(img)
|
|
|
+ _, img = cv2.imencode('.jpg', img)
|
|
|
+ img_str = base64.b64encode(img).decode('utf-8')
|
|
|
+ return img_str
|
|
|
+
|
|
|
+ def get_json(self):
|
|
|
+ with open(self.json_path, 'r') as f:
|
|
|
+ return json.load(f)
|
|
|
+
|
|
|
+
|
|
|
+def send_request(base64_str, config_str, image_type=None):
|
|
|
+ config = CONFIGS[config_str]
|
|
|
+ headers = {
|
|
|
+ 'Content-Type': 'application/json',
|
|
|
+ 'Authorization': config.token
|
|
|
+ }
|
|
|
+ data = {
|
|
|
+ 'image': base64_str,
|
|
|
+ }
|
|
|
+ if image_type:
|
|
|
+ data['image_type'] = image_type
|
|
|
+ response = requests.post(config.url, headers=headers, json=data)
|
|
|
+ return response.json()
|
|
|
+
|
|
|
+
|
|
|
+class Dataset(object):
|
|
|
+ def __init__(self, image_path, image_type, config_str, rotate=False):
|
|
|
+ self.image_type = image_type
|
|
|
+ self.config_str = config_str
|
|
|
self.image_path = image_path
|
|
|
self.image_list = []
|
|
|
for p in chain(*[Path(self.image_path).rglob('*.jpg')]):
|
|
|
- self.image_list.append(Image(p, 0))
|
|
|
+ if rotate:
|
|
|
+ for rotate in [None, 0, 1, 2]:
|
|
|
+ self.image_list.append(Image(p, rotate))
|
|
|
+ else:
|
|
|
+ self.image_list.append(Image(p, None))
|
|
|
+
|
|
|
|
|
|
self.attrs = ['orientation','name','id','language', 'level', 'exam_time', 'score']
|
|
|
|
|
|
- self.tp = {k: 0 for k in self.attrs}
|
|
|
+ self.correct = {k: 0 for k in self.attrs}
|
|
|
+ self.error = {k: 0 for k in self.attrs}
|
|
|
|
|
|
|
|
|
- # self.field_rate = {
|
|
|
- # 'orientation': self.count,
|
|
|
- # 'name': self.count,
|
|
|
- # 'id': self.count,
|
|
|
- # 'language': self.count,
|
|
|
- # 'level': self.count,
|
|
|
- # 'exam_time': self.count,
|
|
|
- # 'score': self.count,
|
|
|
- # }
|
|
|
- # self.del_field = del_field
|
|
|
+ def __len__(self):
|
|
|
+ return len(self.image_list)
|
|
|
|
|
|
|
|
|
def _evaluate_one(self, image: Image):
|
|
|
+ def _get_predict(r, key):
|
|
|
+ if isinstance(r[key], dict):
|
|
|
+ return r[key]['text']
|
|
|
+ else:
|
|
|
+ return r[key]
|
|
|
|
|
|
|
|
|
+ base64_str = image.get_base64()
|
|
|
+ r = send_request(base64_str, self.config_str, self.image_type)
|
|
|
+ print(r)
|
|
|
+ err_str = ''
|
|
|
|
|
|
- @property
|
|
|
- def count(self):
|
|
|
- return len(list(chain(*[self.image_paths.rglob('*.jpg')]))) * 4 if self.is_rotate else 1
|
|
|
+ if r['status'] == '000':
|
|
|
+ res = r['result']
|
|
|
+ for key in self.attrs:
|
|
|
+ if key in res:
|
|
|
+ gt = image.get_gt_result(key)
|
|
|
+ predict = _get_predict(res, key)
|
|
|
+ print(f'gt: {gt}, predict: {predict}')
|
|
|
+ if predict == gt:
|
|
|
+ self.correct[key] += 1
|
|
|
+ else:
|
|
|
+ image.cate = False
|
|
|
+ self.error[key] += 1
|
|
|
+ err_str += f'正确:{gt}<br>返回:{predict}<br>'
|
|
|
+ if image.cate:
|
|
|
+ image.ocr_result = r['result']
|
|
|
+ else:
|
|
|
+ image.ocr_result = err_str
|
|
|
+ else:
|
|
|
+ image.ocr_result = r['msg']
|
|
|
+ image.cate = False
|
|
|
+ for key in self.attrs:
|
|
|
+ self.error[key] += 1
|
|
|
|
|
|
- @property
|
|
|
- def images(self):
|
|
|
- images_path_list = list(chain(*[Path(self.image_paths).rglob('*.jpg')]))
|
|
|
- images_path_dict = {path.name: [str(path), str(path.parent / f'{path.stem}.json')] for path in images_path_list}
|
|
|
- return {i: images_path_dict[i] for i in sorted(images_path_dict)}
|
|
|
+ def evaluate(self):
|
|
|
+ for image in tqdm(self.image_list):
|
|
|
+ self._evaluate_one(image)
|
|
|
|
|
|
- def revise_field_rate(self, field):
|
|
|
- self.field_rate[field] = self.field_rate[field] - 1
|
|
|
|
|
|
- @property
|
|
|
- def field_rate_2_list(self):
|
|
|
- table = []
|
|
|
- for k, v in self.field_rate.items():
|
|
|
- table.extend((k, "{:.2f}%".format(v / self.count * 100)))
|
|
|
- return table
|
|
|
-
|
|
|
- def image_2_base64(self, image):
|
|
|
- dire = self.image_paths.parent / (".ro_dire")
|
|
|
- if not dire.exists(): dire.mkdir()
|
|
|
+ def accuracy(self):
|
|
|
+ print(self.correct.values())
|
|
|
+ return sum(list(self.correct.values())) / sum(list(self.correct.values()) + list(self.error.values()))
|
|
|
|
|
|
- image_path = Path(self.images[image][0])
|
|
|
-
|
|
|
- if self.is_rotate:
|
|
|
- img = cv2.imread(str(image_path))
|
|
|
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
- base64_imgs = []
|
|
|
- for rotate in {cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_180, cv2.ROTATE_90_COUNTERCLOCKWISE}:
|
|
|
- img_rotated = cv2.rotate(img, rotate)
|
|
|
- img_rotated = cv2.cvtColor(img_rotated, cv2.COLOR_BGR2RGB)
|
|
|
- img_rotated_path = dire / f"{image_path.stem}_{str(rotate + 1)}.jpg"
|
|
|
- cv2.imread(str(img_rotated_path), img_rotated)
|
|
|
- with img_rotated_path.open('rb') as f:
|
|
|
- img_str: str = base64.encodebytes(f.read()).decode('utf-8')
|
|
|
- base64_imgs.extend(img_str)
|
|
|
- return base64_imgs
|
|
|
- else:
|
|
|
- with image_path.open('rb') as f:
|
|
|
- img_str: str = base64.encodebytes(f.read()).decode('utf-8')
|
|
|
- return [img_str]
|
|
|
|
|
|
- def res_2_dict(self, r):
|
|
|
- if r['status'] == '000':
|
|
|
- r = r['result']
|
|
|
- if r:
|
|
|
- del r['confidence']
|
|
|
- if self.del_field is not None: del r[self.del_field]
|
|
|
- return {k: v['text'] if isinstance(v, dict) else v for k, v in r.items()}
|
|
|
- elif r['status'] == '101':
|
|
|
- return r['msg']
|
|
|
-
|
|
|
- def json_2_dict(self, image):
|
|
|
- json_path = Path(self.images[image][1])
|
|
|
- with json_path.open('r') as f:
|
|
|
- json_dict = json.load(f)
|
|
|
- if self.del_field is not None: del json_dict[self.del_field]
|
|
|
- return json_dict
|
|
|
-
|
|
|
- def compare_dict(self, MT: MarkdownTable, res_dict, json_dict, image_path):
|
|
|
- image_mark = MT.mdFile.new_inline_image(text='', path=image_path)
|
|
|
- if operator.eq(res_dict, json_dict):
|
|
|
- MT.add_true_table([image_mark, json_dict])
|
|
|
- elif type(res_dict) == dict:
|
|
|
- err_str = ""
|
|
|
- for key in res_dict:
|
|
|
- if res_dict[key] != res_dict[key]:
|
|
|
- err_str = f"{err_str}正确:{res_dict[key]}<br>返回:{res_dict[key]}<br>"
|
|
|
- self.revise_field_rate(key)
|
|
|
- MT.add_false_table(([image_mark, err_str]))
|
|
|
- elif type(res_dict) == str:
|
|
|
- MT.add_false_table([image_mark, res_dict])
|
|
|
+ def attrs_accuracy(self):
|
|
|
+ return {k: self.correct[k] / (self.correct[k] + self.error[k]) for k in self.attrs}
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ dataset = Dataset(IMAGE_PATH, IMAGE_TYPE, CONFIG_STR, False)
|
|
|
+ print(len(dataset))
|
|
|
+ for d in dataset.image_list:
|
|
|
+ print(d)
|
|
|
+
|
|
|
+ dataset.evaluate()
|
|
|
+ print(dataset.accuracy())
|