Bladeren bron

add dataset

Zhang Li 2 jaren geleden
bovenliggende
commit
dca8c6300a
2 gewijzigde bestanden met toevoegingen van 172 en 113 verwijderingen
  1. 0 0
      markdown/__init__.py
  2. 172 113
      markdown/new.py

+ 0 - 0
markdown/__init__.py


+ 172 - 113
markdown/new.py

@@ -1,9 +1,5 @@
-import operator
 from pathlib import Path
-from typing import List
-
-import numpy as np
-from mdutils.mdutils import MdUtils
+from typing import List, Optional
 import cv2
 import requests
 from dataclasses import dataclass
@@ -11,148 +7,211 @@ import json
 import time
 import base64
 from itertools import chain
+from tqdm import tqdm
 
+@dataclass
+class RequestConfig:
+    url: str
+    token: str
 
-class MarkdownTable(object):
-    def __init__(self, name):
-        self.name = name
-        self.mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + name)
-        self.field_table = ['字段', '正确率']
-        self.true_table = ['图片', '识别结果']
-        self.false_table = ['图片', '识别结果']
 
-    def add_field_table(self, fields: List):
-        self.field_table.extend(fields)
 
-    def add_true_table(self, image_and_field: List):
-        self.true_table.extend(image_and_field)
+local_config = RequestConfig(url='http://192.168.199.249:18050/ocr_system/cet', token='')
+test_config = RequestConfig(url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet', token='9679c2b3-b90b-4029-a3c7-f347b4d242f7')
+sb_config = RequestConfig(url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet', token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636')
 
-    def add_false_table(self, image_and_field: List):
-        self.false_table.extend(image_and_field)
+CONFIGS = {
+    'local': local_config,
+    'test': test_config,
+    'sb': sb_config
+}
 
+CONFIG_STR = 'local'
 
-@dataclass
-class Image:
-    path: Path
-    rotate: int
+IMAGE_TYPE = None
 
+IMAGE_PATH = Path('images/cet6/')
 
-    @property
-    def fn(self):
-        return self.path.stem
 
-    @property
-    def json_path(self):
-        return self.path.parent / f'{self.path.stem}.json'
+# class MarkdownTable(object):
+#     def __init__(self, name):
+#         self.name = name
+#         self.mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + name)
+#         self.field_table = ['字段', '正确率']
+#         self.true_table = ['图片', '识别结果']
+#         self.false_table = ['图片', '识别结果']
 
-    def get_base64(self, rotate=0):
-        return 'dsf'
+#     def add_field_table(self, fields: List):
+#         self.field_table.extend(fields)
 
+#     def add_true_table(self, image_and_field: List):
+#         self.true_table.extend(image_and_field)
 
+#     def add_false_table(self, image_and_field: List):
+#         self.false_table.extend(image_and_field)
+
+
+class Image:
+    def __init__(self, path: Path, rotate):
+        self._path = path
+        self.rotate = rotate
+        self._ocr_result = None
+        self.cate = True
+        try:
+            self.gt_result = self.get_json()
+        except Exception as e:
+            print(self.json_path)
+            raise e
+
+    def __repr__(self):
+        return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.cate}'
+
+    @property
+    def path(self):
+        return self._path
+
+    @path.setter
+    def path(self, path):
+        self._path = path
+
+    @property
+    def fn(self):
+        return self._path.stem
 
+    @property
+    def ocr_result(self):
+        return self._ocr_result
 
+    @ocr_result.setter
+    def ocr_result(self, value):
+        self._ocr_result = value
 
+    def get_gt_result(self, key):
+        if key in self.gt_result:
+            return self.gt_result[key]
+        else:
+            return None
 
+    @property
+    def json_path(self):
+        return self.path.parent / f'{self.path.stem}.json'
 
-class DataSet(object):
-    def __init__(self, image_path, rotate=False):
+    def save_image(self, img, rotate=None):
+        dire = self.path.parent / (".ro_dire")
+        if not dire.exists(): dire.mkdir()
+        self.path = dire / f'{self.path.stem}-{rotate+1}.jpg'
+        cv2.imwrite(self.path, img)
+
+    def get_base64(self, rotate=None):
+        print(self.path)
+        img = cv2.imread(str(self.path))
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        if rotate is not None:
+            img = cv2.rotate(img, rotate)
+            self.save_image(img)
+        _, img = cv2.imencode('.jpg', img)
+        img_str = base64.b64encode(img).decode('utf-8')
+        return img_str
+
+    def get_json(self):
+        with open(self.json_path, 'r') as f:
+            return json.load(f)
+
+
+def send_request(base64_str, config_str, image_type=None):
+    config = CONFIGS[config_str]
+    headers = {
+        'Content-Type': 'application/json',
+        'Authorization': config.token
+    }
+    data = {
+        'image': base64_str,
+    }
+    if image_type:
+        data['image_type'] = image_type
+    response = requests.post(config.url, headers=headers, json=data)
+    return response.json()
+
+
+class Dataset(object):
+    def __init__(self, image_path, image_type, config_str, rotate=False):
+        self.image_type = image_type
+        self.config_str = config_str
         self.image_path = image_path
         self.image_list = []
         for p in chain(*[Path(self.image_path).rglob('*.jpg')]):
-            self.image_list.append(Image(p, 0))
+            if rotate:
+                for rotate in [None, 0, 1, 2]:
+                    self.image_list.append(Image(p, rotate))
+            else:
+                self.image_list.append(Image(p, None))
+
 
         self.attrs = ['orientation','name','id','language', 'level', 'exam_time', 'score']
 
-        self.tp = {k: 0 for k in self.attrs}
+        self.correct = {k: 0 for k in self.attrs}
+        self.error = {k: 0 for k in self.attrs}
 
 
-        # self.field_rate = {
-        #     'orientation': self.count,
-        #     'name': self.count,
-        #     'id': self.count,
-        #     'language': self.count,
-        #     'level': self.count,
-        #     'exam_time': self.count,
-        #     'score': self.count,
-        # }
-        # self.del_field = del_field
+    def __len__(self):
+        return len(self.image_list)
 
 
     def _evaluate_one(self, image: Image):
+        def _get_predict(r, key):
+            if isinstance(r[key], dict):
+                return r[key]['text']
+            else:
+                return r[key]
 
 
+        base64_str = image.get_base64()
+        r = send_request(base64_str, self.config_str, self.image_type)
+        print(r)
+        err_str = ''
 
-    @property
-    def count(self):
-        return len(list(chain(*[self.image_paths.rglob('*.jpg')]))) * 4 if self.is_rotate else 1
+        if r['status'] == '000':
+            res = r['result']
+            for key in self.attrs:
+                if key in res:
+                    gt = image.get_gt_result(key)
+                    predict = _get_predict(res, key)
+                    print(f'gt: {gt}, predict: {predict}')
+                    if predict == gt:
+                        self.correct[key] += 1
+                    else:
+                        image.cate = False
+                        self.error[key] += 1
+                        err_str += f'正确:{gt}<br>返回:{predict}<br>'
+            if image.cate:
+                image.ocr_result = r['result']
+            else:
+                image.ocr_result = err_str
+        else:
+            image.ocr_result = r['msg']
+            image.cate = False
+            for key in self.attrs:
+                self.error[key] += 1
 
-    @property
-    def images(self):
-        images_path_list = list(chain(*[Path(self.image_paths).rglob('*.jpg')]))
-        images_path_dict = {path.name: [str(path), str(path.parent / f'{path.stem}.json')] for path in images_path_list}
-        return {i: images_path_dict[i] for i in sorted(images_path_dict)}
+    def evaluate(self):
+        for image in tqdm(self.image_list):
+            self._evaluate_one(image)
 
-    def revise_field_rate(self, field):
-        self.field_rate[field] = self.field_rate[field] - 1
 
-    @property
-    def field_rate_2_list(self):
-        table = []
-        for k, v in self.field_rate.items():
-            table.extend((k, "{:.2f}%".format(v / self.count * 100)))
-        return table
-
-    def image_2_base64(self, image):
-        dire = self.image_paths.parent / (".ro_dire")
-        if not dire.exists(): dire.mkdir()
+    def accuracy(self):
+        print(self.correct.values())
+        return sum(list(self.correct.values())) / sum(list(self.correct.values()) + list(self.error.values()))
 
-        image_path = Path(self.images[image][0])
-
-        if self.is_rotate:
-            img = cv2.imread(str(image_path))
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-            base64_imgs = []
-            for rotate in {cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_180, cv2.ROTATE_90_COUNTERCLOCKWISE}:
-                img_rotated = cv2.rotate(img, rotate)
-                img_rotated = cv2.cvtColor(img_rotated, cv2.COLOR_BGR2RGB)
-                img_rotated_path = dire / f"{image_path.stem}_{str(rotate + 1)}.jpg"
-                cv2.imread(str(img_rotated_path), img_rotated)
-                with img_rotated_path.open('rb') as f:
-                    img_str: str = base64.encodebytes(f.read()).decode('utf-8')
-                    base64_imgs.extend(img_str)
-            return base64_imgs
-        else:
-            with image_path.open('rb') as f:
-                img_str: str = base64.encodebytes(f.read()).decode('utf-8')
-                return [img_str]
 
-    def res_2_dict(self, r):
-        if r['status'] == '000':
-            r = r['result']
-            if r:
-                del r['confidence']
-            if self.del_field is not None: del r[self.del_field]
-            return {k: v['text'] if isinstance(v, dict) else v for k, v in r.items()}
-        elif r['status'] == '101':
-            return r['msg']
-
-    def json_2_dict(self, image):
-        json_path = Path(self.images[image][1])
-        with json_path.open('r') as f:
-            json_dict = json.load(f)
-            if self.del_field is not None: del json_dict[self.del_field]
-            return json_dict
-
-    def compare_dict(self, MT: MarkdownTable, res_dict, json_dict, image_path):
-        image_mark = MT.mdFile.new_inline_image(text='', path=image_path)
-        if operator.eq(res_dict, json_dict):
-            MT.add_true_table([image_mark, json_dict])
-        elif type(res_dict) == dict:
-            err_str = ""
-            for key in res_dict:
-                if res_dict[key] != res_dict[key]:
-                    err_str = f"{err_str}正确:{res_dict[key]}<br>返回:{res_dict[key]}<br>"
-                    self.revise_field_rate(key)
-            MT.add_false_table(([image_mark, err_str]))
-        elif type(res_dict) == str:
-            MT.add_false_table([image_mark, res_dict])
+    def attrs_accuracy(self):
+        return {k: self.correct[k] / (self.correct[k] + self.error[k]) for k in self.attrs}
+
+
+if __name__ == '__main__':
+    dataset = Dataset(IMAGE_PATH, IMAGE_TYPE, CONFIG_STR, False)
+    print(len(dataset))
+    for d in dataset.image_list:
+        print(d)
+
+    dataset.evaluate()
+    print(dataset.accuracy())