zeke-chin 2 years ago
commit
faa03092c9
4 changed files with 364 additions and 0 deletions
  1. .gitignore (+5 −0)
  2. new.py (+267 −0)
  3. ocr_config.py (+55 −0)
  4. use.py (+37 −0)

+ 5 - 0
.gitignore

@@ -0,0 +1,5 @@
+.DS_Store
+.idea
+__pycache__/
+tt
+

+ 267 - 0
new.py

@@ -0,0 +1,267 @@
+from pathlib import Path
+from typing import List, Optional
+import cv2
+import requests
+from mdutils.mdutils import MdUtils
+from dataclasses import dataclass
+import json
+import time
+import base64
+from tqdm import tqdm
+from ocr_config import OCR_CONFIGS
+
+
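+# One evaluation sample: an image file, an optional cv2 rotation code, and the
+# ground-truth JSON stored next to the image.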
+class Image:
+    def __init__(self, path: Path, rotate: Optional[int]):
+        self._path = path
+        self.rotate = rotate  # cv2 rotation code (0/1/2) or None for the original orientation
+        self._ocr_result = None
+        self.category = True  # stays True only while every predicted field matches the ground truth
+        try:
+            self.gt_result = self.get_json()
+        except Exception:
+            # surface which ground-truth file failed to load before re-raising
+            print(self.json_path)
+            raise
+
+    def __repr__(self):
+        return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.category}'
+
+    @property
+    def path(self):
+        return self._path
+
+    @path.setter
+    def path(self, path):
+        self._path = path
+
+    @property
+    def fn(self):
+        return self._path.stem
+
+    @property
+    def ocr_result(self):
+        return self._ocr_result
+
+    @ocr_result.setter
+    def ocr_result(self, value):
+        self._ocr_result = value
+
+    def get_gt_result(self, key):
+        if key == 'orientation':
+            # 0 = original orientation, otherwise the cv2 rotation code shifted by one
+            return self.rotate + 1 if self.rotate is not None else 0
+        return self.gt_result.get(key)
+
+    @property
+    def json_path(self):
+        return self.path.parent / f'{self.path.stem}.json'
+
+    def save_image(self, img, rotate):
+        # write the rotated copy into a sibling ".ro_dst" directory and point self.path at it
+        dst = self.path.parent.parent / '.ro_dst'
+        dst.mkdir(exist_ok=True)
+        self.path = dst / f'{self.path.stem}-{rotate + 1}.jpg'
+        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # back to BGR for cv2.imwrite
+        cv2.imwrite(str(self.path), img)
+
+    def get_base64(self, rotate=None):
+        img = cv2.imread(str(self.path))
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        if rotate is not None:
+            img = cv2.rotate(img, rotate)
+            self.save_image(img, rotate)
+        # convert back to BGR before encoding so the JPEG keeps the original channel order
+        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+        _, buf = cv2.imencode('.jpg', img)
+        return base64.b64encode(buf).decode('utf-8')
+
+    def get_json(self):
+        with open(self.json_path, 'r') as f:
+            return json.load(f)
+
+
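+# POST the base64-encoded image (plus optional image_type) to the configured OCR
+# endpoint and return the decoded JSON response.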
+def send_request(image: Image, ocr_name, ocr_address, image_type=None):
+    base64_str = image.get_base64(image.rotate)
+    config = OCR_CONFIGS[ocr_name][ocr_address]
+    headers = {
+        'Content-Type': 'application/json',
+        'Authorization': config.token
+    }
+    data = {
+        'image': base64_str,
+    }
+    if image_type is not None:
+        data['image_type'] = image_type
+    response = requests.post(config.url, headers=headers, json=data)
+    return response.json()
+
+
+def parser_path(path: Path, rotate: bool):
+    # report file name: "<MM-DD>_<name>[_R].md", written next to the given path
+    name = time.strftime("%m-%d_", time.localtime()) + path.name
+    if rotate:
+        name += '_R'
+    return path.parent / f'{name}.md'
+
+
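+# Walks images_path for .jpg files (optionally adding three rotated variants per file),
+# runs each one through the chosen OCR deployment, and keeps per-field correct/error
+# counts against the ground-truth JSON stored beside each image.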
+class Dataset(object):
+    def __init__(self, images_path, image_type, ocr_name, ocr_address, rotate=False):
+        self.image_type = image_type
+        self.ocr_name = ocr_name
+        self.ocr_address = ocr_address
+        self.images_path = images_path
+        self.image_list = []
+        for p in Path(self.images_path).rglob('*.jpg'):
+            if rotate:
+                # the original orientation plus the three cv2 rotation codes (0, 1, 2)
+                self.image_list.extend(Image(p, r) for r in [None, 0, 1, 2])
+            else:
+                self.image_list.append(Image(p, None))
+
+        self.field = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']
+        # if self.image_type:
+        #     self.field = ['orientation', 'type', 'address', 'address_province', 'address_city', 'address_region',
+        #                   'address_detail']
+        # else:
+        #     self.field = ['orientation', 'name', 'id', 'gender', 'birthplace', 'birthplace_province', 'birthplace_city',
+        #                   'birthplace_region', 'native_place', 'native_place_province', 'native_place_city',
+        #                   'native_place_region', 'blood_type', 'religion']
+
+        self.correct = {k: 0 for k in self.field}
+        self.error = {k: 0 for k in self.field}
+
+    def __len__(self):
+        return len(self.image_list)
+
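+    # Evaluate one image: call the OCR service, compare every configured field with the
+    # ground truth, update the counters, and store either the result or an error summary.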
+    def _evaluate_one(self, image: Image):
+        def _get_predict(r, key):
+            # a field may come back as a bare value or wrapped in a dict with a 'text' entry
+            if isinstance(r[key], dict):
+                return r[key]['text']
+            return r[key]
+
+        if image.rotate is not None:
+            image.gt_result['orientation'] = image.rotate + 1
+        r = send_request(image, self.ocr_name, self.ocr_address, self.image_type)
+        err_str = ''
+        if r['status'] == '000':
+            res = r['result']
+            for key in self.field:
+                if key in res:
+                    gt = image.get_gt_result(key)
+                    predict = _get_predict(res, key)
+                    if predict == gt:
+                        self.correct[key] += 1
+                    else:
+                        image.category = False
+                        self.error[key] += 1
+                        err_str += f'expected: {gt}<br>returned: {predict}<br>'
+            if image.category:
+                image.ocr_result = image.gt_result
+            else:
+                image.ocr_result = err_str
+        else:
+            # request failed: record the error message and count every field as wrong
+            image.ocr_result = r['msg']
+            image.category = False
+            for key in self.field:
+                self.error[key] += 1
+
+    def __call__(self):  # sourcery skip: yield-from
+        # calling the dataset yields its images one by one
+        for image in self.image_list:
+            yield image
+
+    # run every image through the OCR service and compare against ground truth
+    def evaluate(self):
+        for image in tqdm(self.image_list):
+            self._evaluate_one(image)
+
+    # overall accuracy across all fields
+    @property
+    def accuracy(self):
+        correct = sum(self.correct.values())
+        return correct / (correct + sum(self.error.values()))
+
+    # per-field accuracy
+    @property
+    def attrs_accuracy(self):
+        return {k: self.correct[k] / (self.correct[k] + self.error[k]) for k in self.field}
+
+
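+# Renders the evaluation results as a Markdown report via mdutils: an overall accuracy
+# paragraph, a per-field accuracy table, and True/False result tables with inline images.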
+class MD(object):
+    def __init__(self, file_path: Path):
+        self.name = file_path.name
+        self.f = MdUtils(file_name=str(file_path))
+        self.field_table: List = ['Field', 'Accuracy']
+        self.true_table: List = ['Image', 'OCR result']
+        self.false_table: List = ['Image', 'OCR result']
+        self.write_header(f'{self.name} Test Report')
+
+    def write_header(self, title, level=1):
+        self.f.new_header(level=level, title=title)
+
+    def write_total_accuracy(self, ds: Dataset):
+        def get_format_total_accuracy(ds: Dataset):
+            acc = ds.accuracy * 100
+            return "{:.2f}%".format(acc)
+
+        # 1. format the overall accuracy as a percentage string
+        res = get_format_total_accuracy(ds)
+
+        # 2. write it into the report
+        self.f.new_paragraph(res)
+
+    def write_table_accuracy(self, ds: Dataset, columns=2, text_align='center'):
+        def format_table_accuracy(ds: Dataset):
+            table = ds.attrs_accuracy
+            for k, v in table.items():
+                acc = v * 100
+                table[k] = "{:.2f}%".format(acc)
+            return table
+
+        def dict_2_list(dic: dict):
+            # flatten {k: v, ...} into [k1, v1, k2, v2, ...] for MdUtils.new_table
+            flat = []
+            for k, v in dic.items():
+                flat.extend((k, v))
+            return flat
+
+        table_dict = format_table_accuracy(ds)
+        table_list = dict_2_list(table_dict)
+        self.field_table.extend(table_list)
+
+        rows = len(self.field_table) // columns
+        self.f.new_table(columns=columns, rows=rows, text=self.field_table, text_align=text_align)
+
+    def write_table_result(self, ds: Dataset, columns=2, text_align='center'):
+        for image in ds.image_list:
+            md_image = self.f.new_inline_image(text='', path=str(image.path))
+            if image.category:
+                self.true_table.extend([md_image, image.ocr_result])
+            else:
+                self.false_table.extend([md_image, image.ocr_result])
+
+        true_rows = len(self.true_table) // columns
+        false_rows = len(self.false_table) // columns
+        self.write_header('True')
+        self.f.new_table(columns=columns, rows=true_rows, text=self.true_table, text_align=text_align)
+        self.write_header('False')
+        self.f.new_table(columns=columns, rows=false_rows, text=self.false_table, text_align=text_align)
+
+# local smoke test (see use.py for the configurable entry point)
+# if __name__ == '__main__':
+#     markdown = MD(Path('英语等级证书'))
+#
+#     dataset = Dataset(Path(''), None, 'cet', 'local', False)
+#     print(len(dataset))
+#     for d in dataset():
+#         print(d)
+#
+#     dataset.evaluate()
+#     print(dataset.accuracy)
+#
+#     markdown.write_total_accuracy(dataset)
+#     markdown.write_table_accuracy(dataset)
+#     markdown.write_table_result(dataset)
+#
+#     markdown.f.create_md_file()

+ 55 - 0
ocr_config.py

@@ -0,0 +1,55 @@
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class Type:
+    image_type: int
+    image_field: List
+
+
+@dataclass
+class RequestConfig:
+    url: str
+    token: str
+
+
+@dataclass
+class Configs:
+    request: RequestConfig
+    type: Type
+
+
+cet_local_config = RequestConfig(url='http://192.168.199.249:18050/ocr_system/cet', token='')
+cet_test_config = RequestConfig(
+    url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
+    token='9679c2b3-b90b-4029-a3c7-f347b4d242f7')
+cet_sb_config = RequestConfig(
+    url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
+    token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636')
+
+CET_CONFIGS = {
+    'local': cet_local_config,
+    'test': cet_test_config,
+    'sb': cet_sb_config
+}
+
+# regbook
+regbook_local_config = RequestConfig(url='http://192.168.199.249:18020/ocr_system/regbook', token='')
+regbook_test_config = RequestConfig(
+    url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/hkbsb/regbook',
+    token='9679c2b3-b90b-4029-a3c7-f347b4d242f7')
+regbook_sb_config = RequestConfig(
+    url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/hkbsb/regbook',
+    token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636')
+
+REGBOOK_CONFIGS = {
+    'local': regbook_local_config,
+    'test': regbook_test_config,
+    'sb': regbook_sb_config
+}
+
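+# service name -> deployment ('local' / 'test' / 'sb') -> RequestConfig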
+OCR_CONFIGS = {
+    'cet': CET_CONFIGS,
+    'regbook': REGBOOK_CONFIGS
+}

+ 37 - 0
use.py

@@ -0,0 +1,37 @@
+from pathlib import Path
+import time
+
+from new import MD, Dataset, parser_path
+
+# config
+# image directory
+image_path = Path('/Users/zeke/work/sx/OCR/HROCR/hr-ocr-cet/markdown/md/tt/img')
+image_type = None
+# whether to also evaluate rotated copies of every image
+image_rotate = True
+ocr_address = 'local'  # 'local' 'test' 'sb'
+ocr_name = 'cet'  # 'cet' 'idcard' 'bankcard' 'regbook' 'schoolcert'
+md_name = 'CET-tem'
+md_path = '/Users/zeke/work/sx/OCR/HROCR/hr-ocr-cet/markdown/md/tt'
+md_file = parser_path(Path(md_path) / md_name, image_rotate)
+
+
+if __name__ == '__main__':
+    markdown = MD(md_file)
+
+    dataset = Dataset(image_path, image_type, ocr_name, ocr_address, image_rotate)
+    print(len(dataset))
+    for d in dataset():
+        print(d)
+
+    dataset.evaluate()
+    print(dataset.accuracy)
+
+    markdown.write_total_accuracy(dataset)
+    markdown.write_table_accuracy(dataset)
+    markdown.write_table_result(dataset)
+
+    print(md_file)
+    markdown.f.create_md_file()