new.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. from pathlib import Path
  2. from typing import List, Optional
  3. import cv2
  4. import requests
  5. from dataclasses import dataclass
  6. import json
  7. import time
  8. import base64
  9. from itertools import chain
  10. from tqdm import tqdm
  11. @dataclass
  12. class RequestConfig:
  13. url: str
  14. token: str
  15. local_config = RequestConfig(url='http://192.168.199.249:18050/ocr_system/cet', token='')
  16. test_config = RequestConfig(url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet', token='9679c2b3-b90b-4029-a3c7-f347b4d242f7')
  17. sb_config = RequestConfig(url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet', token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636')
  18. CONFIGS = {
  19. 'local': local_config,
  20. 'test': test_config,
  21. 'sb': sb_config
  22. }
  23. CONFIG_STR = 'local'
  24. IMAGE_TYPE = None
  25. IMAGE_PATH = Path('images/cet6/')
  26. # class MarkdownTable(object):
  27. # def __init__(self, name):
  28. # self.name = name
  29. # self.mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + name)
  30. # self.field_table = ['字段', '正确率']
  31. # self.true_table = ['图片', '识别结果']
  32. # self.false_table = ['图片', '识别结果']
  33. # def add_field_table(self, fields: List):
  34. # self.field_table.extend(fields)
  35. # def add_true_table(self, image_and_field: List):
  36. # self.true_table.extend(image_and_field)
  37. # def add_false_table(self, image_and_field: List):
  38. # self.false_table.extend(image_and_field)
  39. class Image:
  40. def __init__(self, path: Path, rotate):
  41. self._path = path
  42. self.rotate = rotate
  43. self._ocr_result = None
  44. self.cate = True
  45. try:
  46. self.gt_result = self.get_json()
  47. except Exception as e:
  48. print(self.json_path)
  49. raise e
  50. def __repr__(self):
  51. return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.cate}'
  52. @property
  53. def path(self):
  54. return self._path
  55. @path.setter
  56. def path(self, path):
  57. self._path = path
  58. @property
  59. def fn(self):
  60. return self._path.stem
  61. @property
  62. def ocr_result(self):
  63. return self._ocr_result
  64. @ocr_result.setter
  65. def ocr_result(self, value):
  66. self._ocr_result = value
  67. def get_gt_result(self, key):
  68. if key in self.gt_result:
  69. return self.gt_result[key]
  70. else:
  71. return None
  72. @property
  73. def json_path(self):
  74. return self.path.parent / f'{self.path.stem}.json'
  75. def save_image(self, img, rotate=None):
  76. dire = self.path.parent / (".ro_dire")
  77. if not dire.exists(): dire.mkdir()
  78. self.path = dire / f'{self.path.stem}-{rotate+1}.jpg'
  79. cv2.imwrite(self.path, img)
  80. def get_base64(self, rotate=None):
  81. print(self.path)
  82. img = cv2.imread(str(self.path))
  83. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  84. if rotate is not None:
  85. img = cv2.rotate(img, rotate)
  86. self.save_image(img)
  87. _, img = cv2.imencode('.jpg', img)
  88. img_str = base64.b64encode(img).decode('utf-8')
  89. return img_str
  90. def get_json(self):
  91. with open(self.json_path, 'r') as f:
  92. return json.load(f)
  93. def send_request(base64_str, config_str, image_type=None):
  94. config = CONFIGS[config_str]
  95. headers = {
  96. 'Content-Type': 'application/json',
  97. 'Authorization': config.token
  98. }
  99. data = {
  100. 'image': base64_str,
  101. }
  102. if image_type:
  103. data['image_type'] = image_type
  104. response = requests.post(config.url, headers=headers, json=data)
  105. return response.json()
  106. class Dataset(object):
  107. def __init__(self, image_path, image_type, config_str, rotate=False):
  108. self.image_type = image_type
  109. self.config_str = config_str
  110. self.image_path = image_path
  111. self.image_list = []
  112. for p in chain(*[Path(self.image_path).rglob('*.jpg')]):
  113. if rotate:
  114. for rotate in [None, 0, 1, 2]:
  115. self.image_list.append(Image(p, rotate))
  116. else:
  117. self.image_list.append(Image(p, None))
  118. self.attrs = ['orientation','name','id','language', 'level', 'exam_time', 'score']
  119. self.correct = {k: 0 for k in self.attrs}
  120. self.error = {k: 0 for k in self.attrs}
  121. def __len__(self):
  122. return len(self.image_list)
  123. def _evaluate_one(self, image: Image):
  124. def _get_predict(r, key):
  125. if isinstance(r[key], dict):
  126. return r[key]['text']
  127. else:
  128. return r[key]
  129. base64_str = image.get_base64()
  130. r = send_request(base64_str, self.config_str, self.image_type)
  131. print(r)
  132. err_str = ''
  133. if r['status'] == '000':
  134. res = r['result']
  135. for key in self.attrs:
  136. if key in res:
  137. gt = image.get_gt_result(key)
  138. predict = _get_predict(res, key)
  139. print(f'gt: {gt}, predict: {predict}')
  140. if predict == gt:
  141. self.correct[key] += 1
  142. else:
  143. image.cate = False
  144. self.error[key] += 1
  145. err_str += f'正确:{gt}<br>返回:{predict}<br>'
  146. if image.cate:
  147. image.ocr_result = r['result']
  148. else:
  149. image.ocr_result = err_str
  150. else:
  151. image.ocr_result = r['msg']
  152. image.cate = False
  153. for key in self.attrs:
  154. self.error[key] += 1
  155. def evaluate(self):
  156. for image in tqdm(self.image_list):
  157. self._evaluate_one(image)
  158. def accuracy(self):
  159. print(self.correct.values())
  160. return sum(list(self.correct.values())) / sum(list(self.correct.values()) + list(self.error.values()))
  161. def attrs_accuracy(self):
  162. return {k: self.correct[k] / (self.correct[k] + self.error[k]) for k in self.attrs}
  163. if __name__ == '__main__':
  164. dataset = Dataset(IMAGE_PATH, IMAGE_TYPE, CONFIG_STR, False)
  165. print(len(dataset))
  166. for d in dataset.image_list:
  167. print(d)
  168. dataset.evaluate()
  169. print(dataset.accuracy())