|
@@ -2,18 +2,20 @@ import copy
|
|
import re
|
|
import re
|
|
from itertools import chain
|
|
from itertools import chain
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
-
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pandas as pd
|
|
import json
|
|
import json
|
|
from mdutils.mdutils import MdUtils
|
|
from mdutils.mdutils import MdUtils
|
|
import requests
|
|
import requests
|
|
-
|
|
|
|
|
|
+import html2text
|
|
from YQ_OCR.config import keyDict
|
|
from YQ_OCR.config import keyDict
|
|
|
|
+from YQ_OCR.to_md.datasets import Dataset
|
|
|
|
+from YQ_OCR.to_md.text2md import TableMD
|
|
|
|
|
|
url = 'http://192.168.199.107:18087'
|
|
url = 'http://192.168.199.107:18087'
|
|
url_path = '/ocr_system/identify'
|
|
url_path = '/ocr_system/identify'
|
|
-imgs_path = '/Users/sxkj/to_md/YQ_OCR/img'
|
|
|
|
|
|
+# imgs_path = '/Users/sxkj/to_md/YQ_OCR/img'
|
|
|
|
+imgs_path = '../img'
|
|
|
|
|
|
|
|
|
|
# 1. xlsx -> 正确json文件(写入厂家信息)
|
|
# 1. xlsx -> 正确json文件(写入厂家信息)
|
|
@@ -48,6 +50,7 @@ def _parse_result(r): # sourcery skip: dict-comprehension
|
|
res[field] = result[field]
|
|
res[field] = result[field]
|
|
res['noKeyList'] = result['noKeyList']
|
|
res['noKeyList'] = result['noKeyList']
|
|
res['logoList'] = result['logoList']
|
|
res['logoList'] = result['logoList']
|
|
|
|
+ res['tableList'] = result['tableList']
|
|
logoFileName = [log['logoFileName'] for log in res['logoList']]
|
|
logoFileName = [log['logoFileName'] for log in res['logoList']]
|
|
res['logoList'] = logoFileName
|
|
res['logoList'] = logoFileName
|
|
return res
|
|
return res
|
|
@@ -78,7 +81,8 @@ def evaluate_one(xlsx_dict, res_dict):
|
|
for key_no_xlsx_no_space, key_no_xlsx in zip(xlsx_dict_no_space['noKeyList'], xlsx_dict['noKeyList']):
|
|
for key_no_xlsx_no_space, key_no_xlsx in zip(xlsx_dict_no_space['noKeyList'], xlsx_dict['noKeyList']):
|
|
key_no_dict[key_no_xlsx_no_space] = []
|
|
key_no_dict[key_no_xlsx_no_space] = []
|
|
for key_no_res in res_dict['noKeyList']:
|
|
for key_no_res in res_dict['noKeyList']:
|
|
- key_no_dict[key_no_xlsx_no_space].append((Levenshtein_Distance(key_no_xlsx_no_space, key_no_res), key_no_res))
|
|
|
|
|
|
+ key_no_dict[key_no_xlsx_no_space].append(
|
|
|
|
+ (Levenshtein_Distance(key_no_xlsx_no_space, key_no_res), key_no_res))
|
|
sort_NoKey = sorted(key_no_dict[key_no_xlsx_no_space], key=lambda x: x[0])
|
|
sort_NoKey = sorted(key_no_dict[key_no_xlsx_no_space], key=lambda x: x[0])
|
|
NoKey_min_distance = sort_NoKey[0][0]
|
|
NoKey_min_distance = sort_NoKey[0][0]
|
|
if NoKey_min_distance == 0:
|
|
if NoKey_min_distance == 0:
|
|
@@ -127,21 +131,23 @@ def evaluate_one(xlsx_dict, res_dict):
|
|
|
|
|
|
# 打开正确的json文件
|
|
# 打开正确的json文件
|
|
def open_true_json(j_path):
|
|
def open_true_json(j_path):
|
|
- with j_path.open('r') as f:
|
|
|
|
|
|
+ with j_path.open('r', encoding='utf-8') as f:
|
|
j_dict = json.load(f)
|
|
j_dict = json.load(f)
|
|
j_json_str = json.dumps(j_dict, ensure_ascii=False)
|
|
j_json_str = json.dumps(j_dict, ensure_ascii=False)
|
|
return j_dict, j_json_str
|
|
return j_dict, j_json_str
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
- img_paths = chain(*[Path(imgs_path).rglob(f'*.{ext}') for ext in ['jpg', 'png', 'jpeg', 'PNG', 'JPG', 'JPEG']])
|
|
|
|
|
|
+ img_paths = chain(*[Path(imgs_path).rglob(f'*.{ext}') for ext in ['jpg', 'png', 'jpeg']])
|
|
all_rate = []
|
|
all_rate = []
|
|
|
|
+ table_mean_acc = []
|
|
for img_path in img_paths:
|
|
for img_path in img_paths:
|
|
print(img_path)
|
|
print(img_path)
|
|
# json result
|
|
# json result
|
|
true_d, true_json = open_true_json(img_path.with_suffix('.json'))
|
|
true_d, true_json = open_true_json(img_path.with_suffix('.json'))
|
|
result = send_request(img_path, true_json)
|
|
result = send_request(img_path, true_json)
|
|
res_d = _parse_result(result)
|
|
res_d = _parse_result(result)
|
|
|
|
+
|
|
# md
|
|
# md
|
|
md_file_path = img_path.parent / (img_path.with_suffix('.md'))
|
|
md_file_path = img_path.parent / (img_path.with_suffix('.md'))
|
|
MD = MdUtils(file_name=str(md_file_path))
|
|
MD = MdUtils(file_name=str(md_file_path))
|
|
@@ -150,10 +156,22 @@ if __name__ == '__main__':
|
|
MD.new_header(level=1, title='测试结果')
|
|
MD.new_header(level=1, title='测试结果')
|
|
MD.new_header(level=2, title=f'正确率:{rate}')
|
|
MD.new_header(level=2, title=f'正确率:{rate}')
|
|
MD.new_header(level=3, title=statistics)
|
|
MD.new_header(level=3, title=statistics)
|
|
- print(f'正确率:{rate}')
|
|
|
|
|
|
+ print(f'文字识别正确率:{rate}')
|
|
MD.new_table(columns=4, rows=len(table_result) // 4, text=table_result, text_align='center')
|
|
MD.new_table(columns=4, rows=len(table_result) // 4, text=table_result, text_align='center')
|
|
MD.create_md_file()
|
|
MD.create_md_file()
|
|
|
|
|
|
- print('-------------------------------')
|
|
|
|
|
|
+ # table gt result
|
|
|
|
+ markdown = TableMD(img_path.name)
|
|
|
|
+ dataset = Dataset(gt_file=img_path.with_suffix('.txt'), img_name=img_path.name, results=res_d)
|
|
|
|
+ markdown.write_header(title='推理结果', level=2)
|
|
|
|
+ markdown.write_table_accuracy(ds=dataset, key='new')
|
|
|
|
+ table_acc = markdown.get_table_accuracy()
|
|
|
|
+ table_mean_acc.append(table_acc)
|
|
|
|
+ print(f'表格识别正确率:{table_acc:.2f}%')
|
|
|
|
+ markdown.f.create_md_file()
|
|
|
|
+
|
|
|
|
+ print('----------------------------------------')
|
|
all_rate = "{:.2f}%".format(np.mean(all_rate) * 100)
|
|
all_rate = "{:.2f}%".format(np.mean(all_rate) * 100)
|
|
- print(f'总体正确率:{all_rate}')
|
|
|
|
|
|
+ all_table_rate = "{:.2f}%".format(np.mean(table_mean_acc))
|
|
|
|
+ print(f'文字识别总体正确率:{all_rate}')
|
|
|
|
+ print(f'表格识别总体正确率:{all_table_rate}')
|