new.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. from pathlib import Path
  2. from typing import List, Optional
  3. import cv2
  4. import requests
  5. from mdutils.mdutils import MdUtils
  6. from dataclasses import dataclass
  7. import json
  8. import time
  9. import base64
  10. from itertools import chain
  11. from tqdm import tqdm
  12. from ocr_config import OCR_CONFIGS, Filed
  13. class Image:
  14. def __init__(self, path: Path, rotate, is_rotate):
  15. self._path = path
  16. self.rotate = rotate
  17. self._ocr_result = None
  18. self.category = True
  19. self.is_rotate = is_rotate
  20. try:
  21. self.gt_result = self.get_json()
  22. except Exception as e:
  23. print(self.json_path)
  24. raise e
  25. def __repr__(self):
  26. return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.category}'
  27. # 将方法转换为相同名称的只读属性
  28. @property
  29. def path(self):
  30. return self._path
  31. @path.setter
  32. def path(self, path):
  33. self._path = path
  34. @property
  35. def fn(self):
  36. return self._path.stem
  37. @property
  38. def ocr_result(self):
  39. return self._ocr_result
  40. @ocr_result.setter
  41. def ocr_result(self, value):
  42. self._ocr_result = value
  43. def get_gt_result(self, key):# sourcery skip: merge-duplicate-blocks, remove-redundant-if
  44. if key == 'orientation':
  45. if self.is_rotate:
  46. return self.rotate + 1 if self.rotate is not None else 0
  47. else:
  48. return self.gt_result[key]
  49. elif key in self.gt_result:
  50. return self.gt_result[key]
  51. else:
  52. return None
  53. @property
  54. def json_path(self):
  55. return self.path.parent / f'{self.path.stem}.json'
  56. def save_image(self, img, rotate):
  57. dst = self.path.parent.parent / (".ro_dst")
  58. if not dst.exists(): dst.mkdir()
  59. self.path = dst / f'{self.path.stem}-{rotate + 1}.jpg'
  60. # print('save image', self.path)
  61. img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  62. cv2.imwrite(str(self.path), img)
  63. return self.path
  64. def get_base64(self, rotate=None):
  65. # print(self.path)
  66. img = cv2.imread(str(self.path))
  67. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  68. path = self.path
  69. if rotate is not None:
  70. img = cv2.rotate(img, rotate)
  71. path = self.save_image(img, rotate)
  72. # imencode 将图片编码到缓存,并保存到本地
  73. with open(path, 'rb') as f:
  74. return base64.encodebytes(f.read()).decode('utf-8')
  75. def get_json(self):
  76. with open(self.json_path, 'r') as f:
  77. return json.load(f)
  78. def send_request(image: Image, ocr_name, ocr_address, image_type=None):
  79. base64_str = image.get_base64(image.rotate)
  80. config = OCR_CONFIGS[ocr_name][ocr_address]
  81. headers = {
  82. 'Content-Type': 'application/json',
  83. 'Authorization': config.token
  84. }
  85. data = {
  86. 'image': base64_str,
  87. }
  88. if image_type is not None:
  89. data['image_type'] = image_type
  90. response = requests.post(config.url, headers=headers, json=data)
  91. return response.json()
  92. def parser_path(path: Path, rotate: bool):
  93. name = time.strftime("%m-%d_", time.localtime()) + path.name
  94. if rotate:
  95. name = f'{name}_R.md'
  96. return path.parent / name
  97. class Dataset(object):
  98. def __init__(self, images_path, image_type, ocr_name, ocr_address, field, rotate=False):
  99. self.image_type = image_type
  100. self.ocr_name = ocr_name
  101. self.ocr_address = ocr_address
  102. self.images_path = images_path
  103. self.image_list = []
  104. # chain 迭代器,首先返回第一个可迭代对象中所有元素,接着返回下一个可迭代对象中所有元素,直到耗尽所有可迭代对象中的元素
  105. # eg:chain('ABC', 'DEF') --> A B C D E F
  106. for p in chain(*[Path(self.images_path).rglob('*.jpg')]):
  107. if rotate:
  108. self.image_list.extend(Image(p, r, rotate) for r in [None, 0, 1, 2])
  109. else:
  110. self.image_list.append(Image(p, None, rotate))
  111. self.field = Filed.get(field)
  112. self.correct = {k: 0 for k in self.field}
  113. self.error = {k: 0 for k in self.field}
  114. def __len__(self):
  115. return len(self.image_list)
  116. def _evaluate_one(self, image: Image):
  117. def _get_predict(r, key):
  118. # isinstance() 函数来判断一个对象是否是一个已知的类型
  119. if isinstance(r[key], dict):
  120. return r[key]['text']
  121. else:
  122. return r[key]
  123. if image.rotate is not None: image.gt_result['orientation'] = image.rotate + 1
  124. r = send_request(image, self.ocr_name, self.ocr_address, self.image_type)
  125. err_str = ''
  126. if r['status'] == '000':
  127. res = r['result']
  128. for key in self.field:
  129. # print('attr: ', key)
  130. if key in res:
  131. gt = image.get_gt_result(key)
  132. predict = _get_predict(res, key)
  133. # print(f'gt: {gt}, predict: {predict}')
  134. if predict == gt:
  135. self.correct[key] += 1
  136. else:
  137. image.category = False
  138. self.error[key] += 1
  139. err_str += f'-------{key}-------<br>正确:{gt}<br>返回:{predict}<br>'
  140. if image.category:
  141. image.ocr_result = image.gt_result
  142. else:
  143. image.ocr_result = err_str
  144. else:
  145. image.ocr_result = r['msg']
  146. image.category = False
  147. for key in self.field:
  148. self.error[key] += 1
  149. def __call__(self): # sourcery skip: yield-from
  150. # yield 返回一个生成器
  151. for image in self.image_list:
  152. yield image
  153. # 比较
  154. def evaluate(self):
  155. for image in tqdm(self.image_list):
  156. self._evaluate_one(image)
  157. # 计算总体准确度
  158. @property
  159. def accuracy(self):
  160. return sum(list(self.correct.values())) / sum(list(self.correct.values()) + list(self.error.values()))
  161. # 计算元素准确度
  162. @property
  163. def attrs_accuracy(self):
  164. return {k: self.correct[k] / (self.correct[k] + self.error[k]) for k in self.field}
  165. class MD(object):
  166. def __init__(self, file_path: Path):
  167. self.name = file_path.name
  168. self.f = MdUtils(file_name=str(file_path))
  169. self.field_table: List = ['字段', '正确率']
  170. self.true_table: List = ['图片', '识别结果']
  171. self.false_table: List = ['图片', '识别结果']
  172. self.write_header(f'{self.name}测试报告')
  173. def write_header(self, title, level=1):
  174. self.f.new_header(level=level, title=title)
  175. def write_total_accuracy(self, ds: Dataset):
  176. def get_format_total_accuracy(ds: Dataset):
  177. acc = ds.accuracy * 100
  178. return "{:.2f}%".format(acc)
  179. # 1. 拿到format之后的百分数
  180. res = get_format_total_accuracy(ds)
  181. # 2. 写入
  182. self.f.new_paragraph(res)
  183. def write_table_accuracy(self, ds: Dataset, columns=2, text_align='center'):
  184. def format_table_accuracy(ds: Dataset):
  185. table = ds.attrs_accuracy
  186. for k, v in table.items():
  187. acc = v * 100
  188. table[k] = "{:.2f}%".format(acc)
  189. return table
  190. def dict_2_list(dic: dict):
  191. l = []
  192. for k, v in dic.items():
  193. l.extend((k, v))
  194. return l
  195. table_dict = format_table_accuracy(ds)
  196. table_list = dict_2_list(table_dict)
  197. self.field_table.extend(table_list)
  198. rows = len(self.field_table) // columns
  199. self.f.new_table(columns=columns, rows=rows, text=self.field_table, text_align=text_align)
  200. def write_table_result(self, ds: Dataset, columns=2, text_align='center'):
  201. for image in ds.image_list:
  202. md_image = self.f.new_inline_image(text='', path=f'{image.path.parent.name}/{image.path.name}')
  203. if image.category:
  204. self.true_table.extend([md_image, image.ocr_result])
  205. else:
  206. self.false_table.extend([md_image, image.ocr_result])
  207. true_rows = len(self.true_table) // columns
  208. false_rows = len(self.false_table) // columns
  209. self.write_header('True')
  210. self.f.new_table(columns=columns, rows=true_rows, text=self.true_table, text_align=text_align)
  211. self.write_header('False')
  212. self.f.new_table(columns=columns, rows=false_rows, text=self.false_table, text_align='left')
  213. # if __name__ == '__main__':
  214. # markdown = MD('英语等级证书')
  215. #
  216. # dataset = Dataset(Path(''), 'cet', 'local', False)
  217. # print(len(dataset))
  218. # for d in dataset():
  219. # print(d)
  220. #
  221. # dataset.evaluate()
  222. # print(dataset.accuracy)
  223. #
  224. # markdown.write_total_accuracy(dataset)
  225. # markdown.write_table_accuracy(dataset)
  226. # markdown.write_table_result(dataset)
  227. #
  228. # markdown.f.create_md_file()