# new.py (6.4 KB)


import base64
import json
import time
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from typing import List, Optional

import cv2
import requests
from tqdm import tqdm
  11. @dataclass
  12. class RequestConfig:
  13. url: str
  14. token: str
  15. local_config = RequestConfig(url='http://192.168.199.249:18050/ocr_system/cet', token='')
  16. test_config = RequestConfig(
  17. url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
  18. token='9679c2b3-b90b-4029-a3c7-f347b4d242f7')
  19. sb_config = RequestConfig(url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
  20. token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636')
  21. CONFIGS = {
  22. 'local': local_config,
  23. 'test': test_config,
  24. 'sb': sb_config
  25. }
  26. CONFIG_STR = 'local'
  27. IMAGE_TYPE = None
  28. IMAGE_PATH = Path('images/cet6/')
# Unused draft of a Markdown report writer (MdUtils is not imported in this file).
# Column headers: 字段 = field, 正确率 = accuracy, 图片 = image, 识别结果 = recognition result.
# class MarkdownTable(object):
#     def __init__(self, name):
#         self.name = name
#         self.mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + name)
#         self.field_table = ['字段', '正确率']
#         self.true_table = ['图片', '识别结果']
#         self.false_table = ['图片', '识别结果']
#
#     def add_field_table(self, fields: List):
#         self.field_table.extend(fields)
#
#     def add_true_table(self, image_and_field: List):
#         self.true_table.extend(image_and_field)
#
#     def add_false_table(self, image_and_field: List):
#         self.false_table.extend(image_and_field)
  42. class Image:
  43. def __init__(self, path: Path, rotate):
  44. self._path = path
  45. self.rotate = rotate
  46. self._ocr_result = None
  47. self.cate = True
  48. try:
  49. self.gt_result = self.get_json()
  50. except Exception as e:
  51. print(self.json_path)
  52. raise e
  53. def __repr__(self):
  54. return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.cate}'
  55. @property
  56. def path(self):
  57. return self._path
  58. @path.setter
  59. def path(self, path):
  60. self._path = path
  61. @property
  62. def fn(self):
  63. return self._path.stem
  64. @property
  65. def ocr_result(self):
  66. return self._ocr_result
  67. @ocr_result.setter
  68. def ocr_result(self, value):
  69. self._ocr_result = value
  70. def get_gt_result(self, key):
  71. if key == 'orientation':
  72. return self.rotate + 1 if self.rotate is not None else 0
  73. elif key in self.gt_result:
  74. return self.gt_result[key]
  75. else:
  76. return None
  77. @property
  78. def json_path(self):
  79. return self.path.parent / f'{self.path.stem}.json'
  80. def save_image(self, img, rotate):
  81. dst = self.path.parent.parent / (".ro_dst")
  82. if not dst.exists(): dst.mkdir()
  83. self.path = dst / f'{self.path.stem}-{rotate+1}.jpg'
  84. print('save image', self.path)
  85. img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  86. cv2.imwrite(str(self.path), img)
  87. def get_base64(self, rotate=None):
  88. print(self.path)
  89. img = cv2.imread(str(self.path))
  90. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  91. if rotate is not None:
  92. img = cv2.rotate(img, rotate)
  93. self.save_image(img, rotate)
  94. _, img = cv2.imencode('.jpg', img)
  95. img_str = base64.b64encode(img).decode('utf-8')
  96. return img_str
  97. def get_json(self):
  98. with open(self.json_path, 'r') as f:
  99. return json.load(f)
  100. def send_request(image: Image, config_str, image_type=None):
  101. base64_str = image.get_base64(image.rotate)
  102. config = CONFIGS[config_str]
  103. headers = {
  104. 'Content-Type': 'application/json',
  105. 'Authorization': config.token
  106. }
  107. data = {
  108. 'image': base64_str,
  109. }
  110. if image_type:
  111. data['image_type'] = image_type
  112. response = requests.post(config.url, headers=headers, json=data)
  113. return response.json()
  114. class Dataset(object):
  115. def __init__(self, image_path, image_type, config_str, rotate=False):
  116. self.image_type = image_type
  117. self.config_str = config_str
  118. self.image_path = image_path
  119. self.image_list = []
  120. for p in chain(*[Path(self.image_path).rglob('*.jpg')]):
  121. if rotate:
  122. for r in [None, 0, 1, 2]:
  123. self.image_list.append(Image(p, r))
  124. else:
  125. self.image_list.append(Image(p, None))
  126. self.attrs = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']
  127. self.correct = {k: 0 for k in self.attrs}
  128. self.error = {k: 0 for k in self.attrs}
  129. def __len__(self):
  130. return len(self.image_list)
  131. def _evaluate_one(self, image: Image):
  132. def _get_predict(r, key):
  133. if isinstance(r[key], dict):
  134. return r[key]['text']
  135. else:
  136. return r[key]
  137. r = send_request(image, self.config_str, self.image_type)
  138. err_str = ''
  139. if r['status'] == '000':
  140. res = r['result']
  141. for key in self.attrs:
  142. print('attr: ', key)
  143. if key in res:
  144. gt = image.get_gt_result(key)
  145. predict = _get_predict(res, key)
  146. print(f'gt: {gt}, predict: {predict}')
  147. if predict == gt:
  148. self.correct[key] += 1
  149. else:
  150. image.cate = False
  151. self.error[key] += 1
  152. err_str += f'正确:{gt}<br>返回:{predict}<br>'
  153. if image.cate:
  154. image.ocr_result = r['result']
  155. else:
  156. image.ocr_result = err_str
  157. else:
  158. image.ocr_result = r['msg']
  159. image.cate = False
  160. for key in self.attrs:
  161. self.error[key] += 1
  162. def __call__(self):
  163. for image in self.image_list:
  164. yield image
  165. def evaluate(self):
  166. for image in tqdm(self.image_list):
  167. self._evaluate_one(image)
  168. def accuracy(self):
  169. return sum(list(self.correct.values())) / sum(list(self.correct.values()) + list(self.error.values()))
  170. def attrs_accuracy(self):
  171. return {k: self.correct[k] / (self.correct[k] + self.error[k]) for k in self.attrs}
  172. if __name__ == '__main__':
  173. dataset = Dataset(IMAGE_PATH, IMAGE_TYPE, CONFIG_STR, True)
  174. print(len(dataset))
  175. for d in dataset():
  176. print(d)
  177. dataset.evaluate()
  178. print(dataset.accuracy())