表格结构模型评估.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. # ---
  2. # jupyter:
  3. # jupytext:
  4. # text_representation:
  5. # extension: .py
  6. # format_name: light
  7. # format_version: '1.5'
  8. # jupytext_version: 1.16.1
  9. # kernelspec:
  10. # display_name: Python 3
  11. # language: python
  12. # name: python3
  13. # ---
  14. # Put these at the top of every notebook, to get automatic reloading and inline plotting
  15. # %reload_ext autoreload
  16. # %autoreload 2
  17. # %matplotlib inline
  18. import sys
  19. sys.path.insert(0, './PaddleOCR/')
  20. # +
  21. from paddleocr import PaddleOCR, PPStructure
  22. from pathlib import Path
  23. # from IPython.core.display import HTML
  24. from IPython.display import display, HTML
  25. from PIL import Image
  26. import re
  27. import cv2
  28. import base64
  29. from io import BytesIO
  30. from PIL import Image, ImageOps
  31. import numpy as np
  32. from matplotlib import pyplot as plt
  33. from sx_utils import *
  34. import json
  35. import threading
  36. from collections import namedtuple
  37. # +
  38. def table_res(img_path, ROTATE=-1):
  39. im = cv2.imread(img_path)
  40. if ROTATE >= 0:
  41. im = cv2.rotate(im, ROTATE)
  42. html = table_engine(im)[0]['res']['html']
  43. return html
  44. def cal_html_to_chs(html):
  45. res = []
  46. rows = re.split('<tr>', html)
  47. for row in rows:
  48. row = re.split('<td>', row)
  49. cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''), row))
  50. rec_str = ''.join(cells)
  51. for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
  52. rec_str = rec_str.replace(tag, '')
  53. res.append(rec_str)
  54. rec_res = ''.join(res).replace(' ','')
  55. rec_res = re.split('<tdcolspan="\w+">', rec_res)
  56. rec_res = ''.join(rec_res).replace(' ','')
  57. print(rec_res)
  58. return len(rec_res)
  59. # -
# Module-wide mutex; _predict_table2 holds it around its table_engine call.
lock = threading.Lock()
  61. def read_annos(path):
  62. with open(path, 'r', encoding='utf-8') as f:
  63. return f.readlines()
  64. from collections import namedtuple
  65. from typing import List
  66. import Levenshtein
  67. from concurrent.futures import ThreadPoolExecutor
  68. from functools import partial
  69. from tqdm import tqdm
  70. import requests
  71. import shutil
  72. # +
  73. from decorator import decorator
  74. @decorator
  75. def rule1_decorator(f, *args,**kwargs):
  76. '''
  77. predict_line = ['项目 ', '', '每100克营养素参考值%', '']
  78. '''
  79. predict_line = args[1]
  80. predict_line = f(*args,**kwargs)
  81. idx = predict_line.index('')
  82. try:
  83. if idx == 1:
  84. if '项目' in predict_line[0] and '每100克' in predict_line[2]:
  85. predict_line[1] = '每100克'
  86. r = re.split('每100克', predict_line[2])
  87. if len(r) == 2 and r[1]:
  88. predict_line[2] = r[1]
  89. except IndexError as e:
  90. print(e)
  91. return predict_line
  92. @decorator
  93. def rule2_decorator(f, *args,**kwargs):
  94. '''
  95. predict_line = ['碳水化合物18.2克', '', '6%', '']
  96. '''
  97. predict_line = args[1]
  98. predict_line = f(*args,**kwargs)
  99. idx = predict_line.index('')
  100. try:
  101. if idx == 1:
  102. if '化合物' in predict_line[0]:
  103. r = re.split('化合物', predict_line[0])
  104. predict_line[0] = '碳水化合物'
  105. if len(r) == 2 and r[1]:
  106. predict_line[1] = r[1]
  107. except IndexError as e:
  108. print(e)
  109. return predict_line
  110. @decorator
  111. def rule3_decorator(f, *args,**kwargs):
  112. '''
  113. ['患直质', '1.6克', '3%', '']
  114. ['脂扇', '1.1', '19%', '']
  115. ['碳水化合物', '勿18.2克', '6%', '']
  116. '''
  117. predict_line = args[1]
  118. predict_line = f(*args,**kwargs)
  119. predict_line = [re.sub('患直质', '蛋白质', s) for s in predict_line]
  120. predict_line = [re.sub('脂扇', '脂肪', s) for s in predict_line]
  121. predict_line = [re.sub('勿(.*克)', '\\1', s) for s in predict_line]
  122. predict_line = [re.sub('毫 克', '毫克', s) for s in predict_line]
  123. return predict_line
  124. # +
  125. WordCompare = namedtuple('WordCompare', ['gt', 'predict', 'cls'])
  126. def get_fn(annos, idx):
  127. return json.loads(annos[idx])['filename']
  128. def get_gt(annos, idx):
  129. return json.loads(annos[idx])['gt']
  130. def get_img_path(annos, idx, root_path):
  131. return root_path / get_fn(annos, idx)
  132. def filter_annos(annos: List[str], fns: List[Path]):
  133. res = []
  134. pattern = '"filename": "(.*.jpg)"'
  135. for anno in annos:
  136. for fn in fns:
  137. m = re.search(fn.name, anno)
  138. if m:
  139. sub = f'"filename": "{str(fn)}"'
  140. new_anno = re.sub(pattern,sub, anno)
  141. res.append(new_anno)
  142. return res
  143. def _predict_table2(annos, idx, root_path):
  144. img_path = get_img_path(annos, idx, root_path=root_path)
  145. try:
  146. lock.acquire()
  147. img_path = str(img_path)
  148. html = table_engine(img_path)[0]['res']['html']
  149. return html
  150. finally:
  151. lock.release()
  152. return None
  153. def _predict_table1(annos, idx, root_path):
  154. img_path = get_img_path(annos, idx, root_path=root_path)
  155. img_path = str(img_path)
  156. img = cv2.imread(img_path)
  157. img_str = cv2bgr_base64(img)
  158. payload = {'image': img_str, 'det': 'conv'}
  159. r = requests.post('http://ocr-table.yili-ocr:8080/ocr_system/table', json=payload)
  160. res = r.json()
  161. if 'status' in res and res['status'] == '000':
  162. return res['result']['html']
  163. else:
  164. return None
  165. def _predict_table3(annos, idx, root_path):
  166. img_path = get_img_path(annos, idx, root_path=root_path)
  167. img_path = str(img_path)
  168. img = cv2.imread(img_path)
  169. img_str = cv2bgr_base64(img)
  170. payload = {'image': img_str, 'det': 'yes', 'prefer_cell': True}
  171. r = requests.post('http://ocr-table.sxkj.com/ocr_system/table', headers={"Content-type": "application/json"}, json=payload)
  172. res = r.json()
  173. if 'status' in res and res['status'] == '000':
  174. return res['result']['html']
  175. else:
  176. return None
  177. def predict_table(annos, idx, root_path):
  178. max_retry = 3
  179. def predict_with_retry(*args, retry=0):
  180. if retry > max_retry:
  181. raise RuntimeError("Max retry failed!")
  182. if retry > 0:
  183. print(f"retry = {retry}, idx = {idx}")
  184. try:
  185. return _predict_table3(*args)
  186. except Exception as e:
  187. print(f'request error: {e}')
  188. return predict_with_retry(*args, retry=retry + 1)
  189. return predict_with_retry(annos, idx, root_path)
  190. class PairLine:
  191. def __init__(self, gt_line, predict_line):
  192. '''
  193. gt_line: ['项目', '每100克', '营养素参考值%', '']
  194. predict_line: ['项目', '', '每100克营养素参考值%', '']
  195. '''
  196. self.gt_line = gt_line
  197. self.predict_line = predict_line
  198. self.result = []
  199. self.cols = 3
  200. def compare(self):
  201. for i, word in enumerate(self.predict_line):
  202. # print(self.predict_line, self.gt_line)
  203. if not self.gt_line:
  204. for j in range(self.cols):
  205. self.result.append(WordCompare(gt='', predict=word, cls=False))
  206. try:
  207. cls = True if word.strip() == self.gt_line[i].strip() else False
  208. if not word and not self.gt_line[i]:
  209. continue
  210. self.result.append(WordCompare(gt=self.gt_line[i], predict=word, cls=cls))
  211. except IndexError:
  212. self.result.append(WordCompare(gt='', predict=word, cls=False))
  213. for i, word in enumerate(self.gt_line):
  214. if not self.predict_line:
  215. for j in range(self.cols):
  216. self.result.append(WordCompare(gt=word, predict='', cls=False))
  217. if i >= len(self.predict_line):
  218. self.result.append(WordCompare(gt=word, predict='', cls=False))
  219. def __repr__(self):
  220. return f'gt_line: {self.gt_line}, predict_line: {self.predict_line}'
  221. # +
  222. class Table:
  223. def __init__(self, fn, gt_html, predict_html):
  224. self.fn = fn
  225. self.gt_html = gt_html
  226. self.predict_html = predict_html
  227. self.format_lines = []
  228. self.pair_lines = []
  229. self.result = self.get_result()
  230. @classmethod
  231. def from_dict(cls, tt):
  232. gt_html = tt['gt_html']
  233. predict_html = tt['predict_html']
  234. fn = Path(tt['fn'])
  235. t = cls(fn, gt_html, predict_html)
  236. return t
  237. def to_dict(self):
  238. return {
  239. 'fn': str(self.fn),
  240. 'gt_html': self.gt_html,
  241. 'predict_html': self.predict_html,
  242. 'pair_lines': [{'gt_line': o.gt_line, 'predict_line': o.predict_line} for o in self.pair_lines],
  243. 'result': [[{'gt': o.gt, 'predict': o.predict, 'cls': o.cls} for o in line] for line in self.result]
  244. }
  245. def display_image(self):
  246. im = Image.open(self.fn)
  247. return ImageOps.exif_transpose(im)
  248. def display_predict_html(self):
  249. return HTML(self.format_predict_html)
  250. def display_gt_html(self):
  251. return HTML(self.gt_html)
  252. @property
  253. def format_predict_html(self):
  254. if self.format_lines:
  255. header = '<html><body><table><tbody>'
  256. footer = '</tbody></table></body></html>'
  257. COLS = 3
  258. html = []
  259. for i, line in enumerate(self.format_lines):
  260. html.append('<tr>')
  261. for j in range(COLS):
  262. try:
  263. if i == 0 and '成分表' in line[j]:
  264. html.append('<td colspan="3">')
  265. html.append(line[j])
  266. html.append('</td>')
  267. break;
  268. else:
  269. html.append('<td>')
  270. html.append(line[j])
  271. html.append('</td>')
  272. except IndexError as e:
  273. print('format_predict_html', e)
  274. html.append('<td>')
  275. html.append('')
  276. html.append('</td>')
  277. continue
  278. html.append('</tr>')
  279. res = f'{header}{"".join(html)}{footer}'
  280. return res
  281. else:
  282. return self.predict_html
  283. @property
  284. def error_rate(self):
  285. corrects = 0
  286. errors = 0
  287. for line in self.result:
  288. for word in line:
  289. if word.cls:
  290. corrects += 1
  291. else:
  292. errors += 1
  293. total = corrects + errors
  294. return 0 if errors == 1 else errors / total
  295. @property
  296. def precision(self):
  297. corrects = 0
  298. p_len = 0
  299. for line in self.result:
  300. for word in line:
  301. if word.cls:
  302. corrects += 1
  303. if word.predict:
  304. p_len += 1
  305. return 0 if p_len == 0 else corrects / p_len
  306. @property
  307. def recall(self):
  308. corrects = 0
  309. g_len = 0
  310. for line in self.result:
  311. for word in line:
  312. if word.cls:
  313. corrects += 1
  314. if word.gt:
  315. g_len += 1
  316. return 0 if g_len == 0 else corrects / g_len
  317. @property
  318. def hmean(self):
  319. total = self.recall + self.precision
  320. return 0 if total == 0 else 2 * self.precision * self.recall / total
  321. def get_result(self):
  322. res = []
  323. self._generate_pair_lines()
  324. # print(self.pair_lines)
  325. for pair_line in self.pair_lines:
  326. pair_line.compare()
  327. res.append(pair_line.result)
  328. return res
  329. @rule3_decorator
  330. @rule2_decorator
  331. @rule1_decorator
  332. def _format_predict_line(self, predict_line):
  333. return predict_line
  334. def _get_lines(self, html) -> List[str]:
  335. '''
  336. res: ['<td>项目</td><td>每100克</td><td>营养素参考值%</td>',...]
  337. '''
  338. if not html:
  339. return []
  340. rows = re.split('<tr>', html)
  341. res = []
  342. for row in rows:
  343. m = re.findall('<td.*>.*</td>', row)
  344. if m:
  345. res.extend(m)
  346. return res
  347. def _generate_pair_lines(self):
  348. gt_lines = self._get_lines(self.gt_html)
  349. predict_lines = self._get_lines(self.predict_html)
  350. gt_words_list = [self._split_to_words(line) for line in gt_lines]
  351. predict_words_list = [self._format_predict_line(self._split_to_words(line)) for line in predict_lines]
  352. self.format_lines.extend(predict_words_list)
  353. DistEntry = namedtuple('DistEntry', ['i', 'j', 'dist'])
  354. dist_entries = []
  355. p = [False] * len(predict_words_list)
  356. g = [False] * len(gt_words_list)
  357. for i, p_line in enumerate(predict_words_list):
  358. for j, g_line in enumerate(gt_words_list):
  359. dist = Levenshtein.distance(''.join(p_line), ''.join(g_line))
  360. dist_entries.append(DistEntry(i=i, j=j, dist=dist))
  361. dist_entries.sort(key=lambda e: e.dist)
  362. for e in dist_entries:
  363. if not p[e.i] and not g[e.j]:
  364. p[e.i] = True
  365. g[e.j] = True
  366. self.pair_lines.append(PairLine(predict_line=predict_words_list[e.i], gt_line=gt_words_list[e.j]))
  367. for i in range(len(p)):
  368. if not p[i]:
  369. self.pair_lines.append(PairLine(predict_line=predict_words_list[i], gt_line=[]))
  370. for i in range(len(g)):
  371. if not g[i]:
  372. self.pair_lines.append(PairLine(predict_line=[], gt_line=gt_words_list[i]))
  373. def _match_gt_line(self, line, gt_lines):
  374. line = ''.join(line)
  375. min_dist = 9999
  376. res = []
  377. for i, gt_line in enumerate(gt_lines):
  378. gt_line = ''.join(gt_line)
  379. dist = Levenshtein.distance(gt_line, line)
  380. if dist < min_dist:
  381. min_dist = dist
  382. res = gt_lines[i]
  383. return res
  384. def _split_to_words(self, line):
  385. '''
  386. line: '<td>项目</td><td>每100克</td><td>营养素参考值%</td>'
  387. res: ['项目', '每100克', '营养素参考值%', '']
  388. '''
  389. res = [re.sub('<td.*>', '', word) for word in re.split('</td>', line)]
  390. return res
  391. def generate_tables(annos, root_path, i):
  392. predict_html = predict_table(annos, i, root_path=root_path)
  393. gt_html = get_gt(annos, i)
  394. fn = get_img_path(annos, i, root_path=root_path)
  395. table = Table(fn, gt_html=gt_html, predict_html=predict_html)
  396. return table
  397. class TableDataset:
  398. def __init__(self, root_path=None, anno_fn=None):
  399. if root_path and anno_fn:
  400. self.tables = []
  401. annos = read_annos(anno_fn)
  402. l = len(annos)
  403. # l = 10
  404. with ThreadPoolExecutor(max_workers=10) as executor:
  405. tables = list(tqdm(executor.map(partial(generate_tables, annos, root_path), range(l)), total=l))
  406. for table in tables:
  407. self.tables.append(table)
  408. else:
  409. self.tables = []
  410. @property
  411. def correct_num(self):
  412. return len(list(filter(lambda x: x.error_rate == 0., self.tables)))
  413. @property
  414. def avg_error_rate(self):
  415. return np.mean([o.error_rate for o in self.tables])
  416. @property
  417. def avg_precision(self):
  418. return np.mean([o.precision for o in self.tables])
  419. @property
  420. def avg_recall(self):
  421. return np.mean([o.recall for o in self.tables])
  422. @property
  423. def avg_hmean(self):
  424. return np.mean([o.hmean for o in self.tables])
  425. def save_hard_cases_for_dataset(self, th, root_path, anno_fn, dst_path):
  426. tables = self.top_error_tables_by_threshold(th)
  427. fns = [t.fn for t in tables]
  428. if not (root_path / dst_path).exists():
  429. (root_path / dst_path).mkdir()
  430. for fn in tqdm(fns):
  431. src = fn
  432. dst = root_path / dst_path / fn.name
  433. shutil.copy2(src, dst)
  434. fns = [dst_path / t.fn.name for t in tables]
  435. annos = read_annos(anno_fn)
  436. annos = filter_annos(annos, fns)
  437. write_annos(annos, root_path / dst_path / 'gt.txt')
  438. def top_error_tables_by_threshold(self, th):
  439. res = []
  440. for r in self.tables:
  441. if r.error_rate >= th:
  442. res.append(r)
  443. return res
  444. def top_error_tables(self, k):
  445. tables = sorted(self.tables, key=lambda x: x.error_rate, reverse=True)
  446. return tables[:k]
  447. def to_json(self, fn):
  448. res = [o.to_dict() for o in self.tables]
  449. with open(fn, 'w', encoding='utf-8') as f:
  450. json.dump(res, f)
  451. @classmethod
  452. def from_json(cls, fn):
  453. res = cls()
  454. with open(fn, 'r') as f:
  455. ts = json.load(f)
  456. for o in ts:
  457. res.tables.append(Table.from_dict(o))
  458. return res
  459. # -
# Dataset location and the annotation file to evaluate.
root_path = Path('table-dataset')
anno_fn = root_path / 'merge.txt'
# anno_fn = root_path / 'hardimgs/gt.txt'

# + jupyter={"outputs_hidden": true}
# Run the full evaluation and persist it under a timestamped file name.
import time
cur_time = int(time.time())
table_ds = TableDataset(root_path, anno_fn)
json_filename = f'table_merge_{cur_time}.json'
table_ds.to_json(json_filename)
print(json_filename)

# +
# table_ds = TableDataset.from_json('table_unconv_request_post_1708499925.json')
# -

# Dataset-level metrics (shown as the cell's output).
table_ds.avg_precision, table_ds.avg_recall, table_ds.avg_hmean, table_ds.correct_num

# The k tables with the highest error rate, worst first.
k = 100
errors = table_ds.top_error_tables(k)

# Picks the entry at index 9; the loop below rebinds `tt` on every iteration.
tt = errors[9]

# Show image, predicted table and ground-truth table for the first 30 entries.
for i, tt in enumerate(errors[:30]):
    print(f'id: {i}, precision: {tt.precision}, recall: {tt.recall}')
    display(tt.display_image())
    display(tt.display_predict_html())
    display(tt.display_gt_html())

# Inspect whichever table `tt` is currently bound to.
tt.display_image()

tt.display_predict_html()

tt.display_gt_html()

tt.predict_html

tt.gt_html

tt.fn

# Count, then export, the tables with an error rate of at least 0.1.
len(table_ds.top_error_tables_by_threshold(0.1))

table_ds.save_hard_cases_for_dataset(0.1, root_path, anno_fn, Path('hardimgs'))