Browse Source

修复 旋转bug 修复文件读取bug(cv2 -> with.open) & 添加 不同ocr_name对应字段

zeke-chin 2 years ago
parent
commit
e7b714910a
3 changed files with 63 additions and 28 deletions
  1. 23 19
      new.py
  2. 34 2
      ocr_config.py
  3. 6 7
      use.py

+ 23 - 19
new.py

@@ -9,16 +9,16 @@ import time
 import base64
 from itertools import chain
 from tqdm import tqdm
-import numpy as np
-from ocr_config import OCR_CONFIGS
+from ocr_config import OCR_CONFIGS, Filed
 
 
 class Image:
-    def __init__(self, path: Path, rotate):
+    def __init__(self, path: Path, rotate, is_rotate):
         self._path = path
         self.rotate = rotate
         self._ocr_result = None
         self.category = True
+        self.is_rotate = is_rotate
         try:
             self.gt_result = self.get_json()
         except Exception as e:
@@ -28,6 +28,7 @@ class Image:
     def __repr__(self):
         return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.category}'
 
+    # 将方法转换为相同名称的只读属性
     @property
     def path(self):
         return self._path
@@ -48,9 +49,12 @@ class Image:
     def ocr_result(self, value):
         self._ocr_result = value
 
-    def get_gt_result(self, key):
+    def get_gt_result(self, key):# sourcery skip: merge-duplicate-blocks, remove-redundant-if
         if key == 'orientation':
-            return self.rotate + 1 if self.rotate is not None else 0
+            if self.is_rotate:
+                return self.rotate + 1 if self.rotate is not None else 0
+            else:
+                return self.gt_result[key]
         elif key in self.gt_result:
             return self.gt_result[key]
         else:
@@ -67,16 +71,19 @@ class Image:
         # print('save image', self.path)
         img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
         cv2.imwrite(str(self.path), img)
+        return self.path
 
     def get_base64(self, rotate=None):
         # print(self.path)
         img = cv2.imread(str(self.path))
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        path = self.path
         if rotate is not None:
             img = cv2.rotate(img, rotate)
-            self.save_image(img, rotate)
-        _, img = cv2.imencode('.jpg', img)
-        return base64.b64encode(img).decode('utf-8')
+            path = self.save_image(img, rotate)
+            # imencode 将图片编码到缓存,并保存到本地
+        with open(path, 'rb') as f:
+            return base64.encodebytes(f.read()).decode('utf-8')
 
     def get_json(self):
         with open(self.json_path, 'r') as f:
@@ -107,26 +114,22 @@ def parser_path(path: Path, rotate: bool):
 
 
 class Dataset(object):
-    def __init__(self, images_path, image_type, ocr_name, ocr_address, rotate=False):
+    def __init__(self, images_path, image_type, ocr_name, ocr_address, field, rotate=False):
         self.image_type = image_type
         self.ocr_name = ocr_name
         self.ocr_address = ocr_address
         self.images_path = images_path
         self.image_list = []
+        # chain 迭代器,首先返回第一个可迭代对象中所有元素,接着返回下一个可迭代对象中所有元素,直到耗尽所有可迭代对象中的元素
+        # eg:chain('ABC', 'DEF') --> A B C D E F
+
         for p in chain(*[Path(self.images_path).rglob('*.jpg')]):
             if rotate:
-                self.image_list.extend(Image(p, r) for r in [None, 0, 1, 2])
+                self.image_list.extend(Image(p, r, rotate) for r in [None, 0, 1, 2])
             else:
-                self.image_list.append(Image(p, None))
+                self.image_list.append(Image(p, None, rotate))
 
-        self.field = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']
-        # if self.image_type:
-        #     self.field = ['orientation', 'type', 'address', 'address_province', 'address_city', 'address_region',
-        #                   'address_detail']
-        # else:
-        #     self.field = ['orientation', 'name', 'id', 'gender', 'birthplace', 'birthplace_province', 'birthplace_city',
-        #                   'birthplace_region', 'native_place', 'native_place_province', 'native_place_city',
-        #                   'native_place_region', 'blood_type', 'religion']
+        self.field = Filed.get(field)
 
         self.correct = {k: 0 for k in self.field}
         self.error = {k: 0 for k in self.field}
@@ -136,6 +139,7 @@ class Dataset(object):
 
     def _evaluate_one(self, image: Image):
         def _get_predict(r, key):
+            # isinstance() 函数来判断一个对象是否是一个已知的类型
             if isinstance(r[key], dict):
                 return r[key]['text']
             else:

+ 34 - 2
ocr_config.py

@@ -19,7 +19,7 @@ class Configs:
     type: Type
 
 
-
+# cet
 cet_local_config = RequestConfig(url='http://192.168.199.249:18050/ocr_system/cet', token='')
 cet_test_config = RequestConfig(
     url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
@@ -35,7 +35,7 @@ CET_CONFIGS = {
 }
 
 # regbook
-regbook_local_config = RequestConfig(url='http://192.168.199.249:18020/ocr_system/regbook', token='')
+regbook_local_config = RequestConfig(url='http://192.168.199.249:18040/ocr_system/regbook', token='')
 regbook_test_config = RequestConfig(
     url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/hkbsb/regbook',
     token='9679c2b3-b90b-4029-a3c7-f347b4d242f7')
@@ -53,3 +53,35 @@ OCR_CONFIGS = {
     'cet': CET_CONFIGS,
     'regbook': REGBOOK_CONFIGS
 }
+
+# business_license
+business_license_local_config = RequestConfig(url='http://192.168.199.249:18060/ocr_system/business_license', token='')
+business_license_test_config = RequestConfig(
+    url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/hkbsb/regbook',
+    token='9679c2b3-b90b-4029-a3c7-f347b4d242f7')
+business_license_sb_config = RequestConfig(
+    url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/hkbsb/regbook',
+    token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636')
+
+BUSINESS_CONFIGS = {
+    'local': business_license_local_config,
+    'test': business_license_test_config,
+    'sb': business_license_sb_config
+}
+
+OCR_CONFIGS = {
+    'cet': CET_CONFIGS,
+    'regbook': REGBOOK_CONFIGS,
+    'business_license': BUSINESS_CONFIGS
+}
+
+# 字段
+cet_field = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']
+regbook_field = ['orientation', 'name', 'id', 'gender', 'birthplace', 'birthplace_province', 'birthplace_city',
+                          'birthplace_region', 'native_place', 'native_place_province', 'native_place_city',
+                          'native_place_region', 'blood_type', 'religion']
+
+Filed = {
+    'cet': cet_field,
+    'regbook': regbook_field
+}

+ 6 - 7
use.py

@@ -5,13 +5,14 @@ from new import MD, Image, Dataset, parser_path
 
 # config
 # 图片路径
-image_path = Path('/Users/zeke/work/sx/OCR/HROCR/to_md/example/img')
-image_type = None
+image_path = Path('/Users/zeke/Downloads/9.1/专四/img')
+image_type = 0
 # 是否旋转
-image_rotate = True
+image_rotate = False
 ocr_address = 'local'  # 'local' 'test' 'sb'
 ocr_name = 'cet'  # 'cet' 'idcard' 'bankcard' 'regbook' 'schoolcert'
-md_name = 'CET-tem'
+md_name = 'CET'
+filed = 'cet'
 # 若md_path为None 则默认使用图片父路径为markdown保存路径
 # md_path = '/Users/zeke/work/sx/OCR/HROCR/to_md/example' or image_path.parent
 md_path = None or image_path.parent
@@ -19,12 +20,10 @@ md_path = None or image_path.parent
 md_file = parser_path(Path(md_path) / Path(md_name), image_rotate)
 
 
-
-
 if __name__ == '__main__':
     markdown = MD(md_file)
 
-    dataset = Dataset(image_path, image_type, ocr_name, ocr_address, image_rotate)
+    dataset = Dataset(image_path, image_type, ocr_name, ocr_address, filed, image_rotate)
     print(len(dataset))
     for d in dataset():
         print(d)