Jelajahi Sumber

更新 后处理 模型替换 from jupyter

zeke-chin 2 tahun lalu
induk
melakukan
dcff76da74
8 mengubah file dengan 157 tambahan dan 291 penghapusan
  1. 1 0
      .gitignore
  2. 143 7
      core/line_parser.py
  3. 1 0
      core/ocr.py
  4. 10 4
      core/parser.py
  5. 0 28
      markdown/08-17英语等级证书_CET4.md
  6. 0 222
      markdown/new.py
  7. 0 27
      markdown/use_new.py
  8. 2 3
      server.py

+ 1 - 0
.gitignore

@@ -1,5 +1,6 @@
 .DS_Store
 .idea
+.vscode
 convert_*
 generate_test.py
 /images/test

+ 143 - 7
core/line_parser.py

@@ -56,7 +56,143 @@ class OcrResult(object):
         x_idx = 1 - y_idx
         if b.lt[x_idx] < self.lt[x_idx] < self.rb[x_idx] < b.rb[x_idx]: return False
         if self.lt[x_idx] < b.lt[x_idx] < b.rb[x_idx] < self.rb[x_idx]: return False
-        eps = 0.35 * (self.wh[y_idx] + b.wh[y_idx])
+        eps = 0.25 * (self.wh[y_idx] + b.wh[y_idx])
+        dist = abs(self.center[y_idx] - b.center[y_idx])
+        return dist < eps
+
+
+# 行处理器
+
+class LineParser(object):
+    def __init__(self, ocr_raw_result, filters=None):
+        # if filters is None:
+        #     filters = [lambda x: x.is_slope]
+        self.ocr_res = []
+        for re in ocr_raw_result:
+            o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
+            if any([f(o) for f in filters]): continue
+            self.ocr_res.append(o)
+        # for f in filters:
+        #     self.ocr_res = list(filter(f, self.ocr_res))
+        self.ocr_res = sorted(self.ocr_res, key=lambda x: x.area, reverse=True)
+        self.eps = self.avg_height * 0.7
+
+    @property
+    def is_horizontal(self):
+        res = self.ocr_res
+        wh = np.stack([np.abs(np.array(r.lt) - np.array(r.rb)) for r in res])
+        return np.sum(wh[:, 0] > wh[:, 1]) > np.sum(wh[:, 0] < wh[:, 1])
+
+    @property
+    def avg_height(self):
+        idx = self.is_horizontal + 0
+        return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
+
+    # 整体置信度
+    @property
+    def confidence(self):
+        return np.mean([r.conf for r in self.ocr_res])
+
+    # 处理器函数
+    # @sxtimeit
+    def parse(self, eps=40.0):
+        # 存返回值
+        res = []
+
+        # 需要 处理的 OcrResult 对象  的长度
+        length = len(self.ocr_res)
+        print('length: ', length)
+
+        # 如果字段数 小于等于1 就抛出异常
+        if length <= 1:
+            raise Exception('无法识别')
+
+        in_lines = set()
+        # 遍历数组 并处理他
+        for i in range(length):
+            # print('in lines', in_lines)
+            # 拿出 OcrResult对象的 第i值 -暂存-
+            res_i = self.ocr_res[i]
+
+            # 这次的 res_i 之前已经在结果集中,就继续下一个
+            # if any(map(lambda x: res_i in x, res)): continue
+            if i in in_lines: continue
+            # set() -> {}
+            # 初始化一个集合 即-输出-
+            res_row = set()
+
+            for j in range(i, length):
+                res_j = self.ocr_res[j]
+                # 这次的 res_i 之前已经在结果集中,就继续下一个
+                # if any(map(lambda x: res_j in x, res)): continue
+                if j in in_lines: continue
+                if res_i.one_line(res_j, self.is_horizontal, self.eps):
+                    # LineParser 对象  不可以直接加入字典
+
+                    res_row.add(res_j)
+                    in_lines.add(j)
+            res.append(res_row)
+        idx = self.is_horizontal + 0
+        return sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in res], key=lambda x: x[0].lt[idx])
+import numpy as np
+from dataclasses import dataclass
+
+
+# result 对象
+@dataclass
+class OcrResult(object):
+    box: np.ndarray
+    txt: str
+    conf: float
+
+    def __hash__(self):
+        return hash(repr(self))
+
+    def __repr__(self):
+        return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'
+
+    @property
+    def lt(self):
+        l, t = np.min(self.box, 0)
+        return [l, t]
+
+    @property
+    def rb(self):
+        r, b = np.max(self.box, 0)
+        return [r, b]
+
+    @property
+    def wh(self):
+        l, t = self.lt
+        r, b = self.rb
+        return [r - l, b - t]
+
+    @property
+    def area(self):
+        w, h = self.wh
+        return w * h
+
+    @property
+    def is_slope(self):
+        p0 = self.box[0]
+        p1 = self.box[1]
+        if p0[0] == p1[0]:
+            return False
+        slope = abs(1. * (p0[1] - p1[1]) / (p0[0] - p1[0]))
+        return 0.4 < slope < 2.5
+
+    @property
+    def center(self):
+        l, t = self.lt
+        r, b = self.rb
+        return [(r + l) / 2, (b + t) / 2]
+
+    def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
+        y_idx = 0 + is_horizontal
+        x_idx = 1 - y_idx
+        if b.lt[x_idx] < self.lt[x_idx] < self.rb[x_idx] < b.rb[x_idx]: return False
+        if self.lt[x_idx] < b.lt[x_idx] < b.rb[x_idx] < self.rb[x_idx]: return False
+        eps = 0.25 * (self.wh[y_idx] + b.wh[y_idx])
         dist = abs(self.center[y_idx] - b.center[y_idx])
         return dist < eps
 
@@ -75,6 +211,11 @@ class LineParser(object):
         #     self.ocr_res = list(filter(f, self.ocr_res))
         self.ocr_res = sorted(self.ocr_res, key=lambda x: x.area, reverse=True)
         self.eps = self.avg_height * 0.7
+        # self.ocr_res = []
+        # for re in ocr_raw_result:
+        #     o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
+        #     self.ocr_res.append(o)
+        # self.eps = self.avg_height * 0.86
 
     @property
     def is_horizontal(self):
@@ -127,9 +268,4 @@ class LineParser(object):
                     res_row.add(res_j)
             res.append(res_row)
         idx = self.is_horizontal + 0
-        res = sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in res], key=lambda x: x[0].lt[idx])
-        for row in res:
-            print('---')
-            print(''.join([r.txt for r in row]))
-        return res
-
+        return sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in res], key=lambda x: x[0].lt[idx])

+ 1 - 0
core/ocr.py

@@ -28,6 +28,7 @@ class CetOcr:
 
         # 旋转后img angle result(生ocr)
         image, angle, result, image_type = self._pre_process(image)
+        cv2.imwrite('dd.jpg', image)
         print(f'---------- detect angle: {angle} 角度 --------')
         if angle != 0:
             _, _, result = self._ocr(image)

+ 10 - 4
core/parser.py

@@ -93,7 +93,10 @@ class CETParser(Parser):
             txt = res[-1][0]
             conf = res[-1][1]
 
-            id_num = re.findall("\d{10,18}[X|x|×]*", txt)
+            id_num = re.findall("\d{17,19}[X|x|×]*", txt)
+            if id_num and len(id_num[0]) == 19 and id_num[0][0] == id_num[0][1]:
+                self.res['id'] = RecItem(id_num[0][1:], conf)
+                break
             if id_num and len(id_num[0]) == 18:
                 self.res['id'] = RecItem(id_num[0].replace('x', "X").replace('×', "X"), conf)
                 break
@@ -151,7 +154,7 @@ class CETParser(Parser):
                 if len(score[0]) == 4 and score[0][0] == score[0][1]:
                     self.res["score"] = RecItem(score[0][1:], conf)
                     return
-                self.res["score"] = RecItem(txt[2:], conf)
+                self.res["score"] = RecItem(score[0], conf)
                 return
 
         for i in range(len(self.result)):
@@ -160,8 +163,11 @@ class CETParser(Parser):
             conf = res[-1][1]
 
             if "时间" in txt:
-                txt = txt.split("月")[-1][:3]
-                self.res["score"] = RecItem(txt, conf)
+                if '月' in txt:
+                    txt = txt.split("月")[-1][:3]
+                    self.res["score"] = RecItem(txt, conf)
+                else:
+                    self.res["score"] = RecItem(res[1].txt, conf)
                 return
 
     def to_data(self, txt):

+ 0 - 28
markdown/08-17英语等级证书_CET4.md

@@ -1,28 +0,0 @@
-
-
-
-
-# 英语等级证书_CET4
-
-
-100.00%
-# True
-
-|图片|识别结果|
-| :---: | :---: |
-|![](/Users/zeke/work/sx/OCR/HROCR/hr-ocr-cet/images/cet6/1_img.jpg)|{'orientation': 0, 'name': '姚学娇', 'id': '150221199909170340', 'language': '英语', 'level': 'CET6', 'exam_time': '2021年6月', 'score': '428'}|
-|![](/Users/zeke/work/sx/OCR/HROCR/hr-ocr-cet/images/cet6/2_img.jpg)|{'orientation': 0, 'name': '潘奕锦', 'id': '150802199610248725', 'language': '英语', 'level': 'CET6', 'exam_time': '2018年12月', 'score': '504'}|
-|![](/Users/zeke/work/sx/OCR/HROCR/hr-ocr-cet/images/cet6/3_img.jpg)|{'orientation': 0, 'name': '吕昭颖', 'id': '450922199610094625', 'language': '英语', 'level': 'CET6', 'exam_time': '2016年06月', 'score': '442'}|
-|![](/Users/zeke/work/sx/OCR/HROCR/hr-ocr-cet/images/cet6/4_img.jpg)|{'orientation': 0, 'name': '苏朋阳', 'id': '130525199903305118', 'language': '英语', 'level': 'CET6', 'exam_time': '2019年6月', 'score': '436'}|
-
-# False
-
-|字段|正确率|
-| :---: | :---: |
-|orientation|100.00%|
-|name|100.00%|
-|id|100.00%|
-|language|100.00%|
-|level|100.00%|
-|exam_time|100.00%|
-|score|100.00%|

+ 0 - 222
markdown/new.py

@@ -1,222 +0,0 @@
-from pathlib import Path
-from typing import List, Optional
-import cv2
-import requests
-from dataclasses import dataclass
-import json
-import time
-import base64
-from itertools import chain
-from tqdm import tqdm
-
-
-@dataclass
-class RequestConfig:
-    url: str
-    token: str
-
-
-local_config = RequestConfig(url='http://192.168.199.249:18050/ocr_system/cet', token='')
-test_config = RequestConfig(
-    url='http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
-    token='9679c2b3-b90b-4029-a3c7-f347b4d242f7')
-sb_config = RequestConfig(url='http://aihub.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/cettest/cet',
-                          token='dcae8cc6-0e49-4db8-a2d2-94ef84da3636')
-
-CONFIGS = {
-    'local': local_config,
-    'test': test_config,
-    'sb': sb_config
-}
-
-CONFIG_STR = 'local'
-
-IMAGE_TYPE = None
-
-IMAGE_PATH = Path('images/cet6/')
-
-
-# class MarkdownTable(object):
-#     def __init__(self, name):
-#         self.name = name
-#         self.mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + name)
-#         self.field_table = ['字段', '正确率']
-#         self.true_table = ['图片', '识别结果']
-#         self.false_table = ['图片', '识别结果']
-
-#     def add_field_table(self, fields: List):
-#         self.field_table.extend(fields)
-
-#     def add_true_table(self, image_and_field: List):
-#         self.true_table.extend(image_and_field)
-
-#     def add_false_table(self, image_and_field: List):
-#         self.false_table.extend(image_and_field)
-
-
-class Image:
-    def __init__(self, path: Path, rotate):
-        self._path = path
-        self.rotate = rotate
-        self._ocr_result = None
-        self.cate = True
-        try:
-            self.gt_result = self.get_json()
-        except Exception as e:
-            print(self.json_path)
-            raise e
-
-    def __repr__(self):
-        return f'path: {self.path}, rotate: {self.rotate}, gt_result: {self.gt_result}, cate: {self.cate}'
-
-    @property
-    def path(self):
-        return self._path
-
-    @path.setter
-    def path(self, path):
-        self._path = path
-
-    @property
-    def fn(self):
-        return self._path.stem
-
-    @property
-    def ocr_result(self):
-        return self._ocr_result
-
-    @ocr_result.setter
-    def ocr_result(self, value):
-        self._ocr_result = value
-
-    def get_gt_result(self, key):
-        if key == 'orientation':
-            return self.rotate + 1 if self.rotate is not None else 0
-        elif key in self.gt_result:
-            return self.gt_result[key]
-        else:
-            return None
-
-
-    @property
-    def json_path(self):
-        return self.path.parent / f'{self.path.stem}.json'
-
-    def save_image(self, img, rotate):
-        dst = self.path.parent.parent / (".ro_dst")
-        if not dst.exists(): dst.mkdir()
-        self.path = dst / f'{self.path.stem}-{rotate+1}.jpg'
-        print('save image', self.path)
-        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-        cv2.imwrite(str(self.path), img)
-
-
-    def get_base64(self, rotate=None):
-        print(self.path)
-        img = cv2.imread(str(self.path))
-        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        if rotate is not None:
-            img = cv2.rotate(img, rotate)
-            self.save_image(img, rotate)
-        _, img = cv2.imencode('.jpg', img)
-        img_str = base64.b64encode(img).decode('utf-8')
-        return img_str
-
-    def get_json(self):
-        with open(self.json_path, 'r') as f:
-            return json.load(f)
-
-
-def send_request(image: Image, config_str, image_type=None):
-    base64_str = image.get_base64(image.rotate)
-    config = CONFIGS[config_str]
-    headers = {
-        'Content-Type': 'application/json',
-        'Authorization': config.token
-    }
-    data = {
-        'image': base64_str,
-    }
-    if image_type:
-        data['image_type'] = image_type
-    response = requests.post(config.url, headers=headers, json=data)
-    return response.json()
-
-
-class Dataset(object):
-    def __init__(self, image_path, image_type, config_str, rotate=False):
-        self.image_type = image_type
-        self.config_str = config_str
-        self.image_path = image_path
-        self.image_list = []
-        for p in chain(*[Path(self.image_path).rglob('*.jpg')]):
-            if rotate:
-                for r in [None, 0, 1, 2]:
-                    self.image_list.append(Image(p, r))
-            else:
-                self.image_list.append(Image(p, None))
-
-        self.attrs = ['orientation', 'name', 'id', 'language', 'level', 'exam_time', 'score']
-
-        self.correct = {k: 0 for k in self.attrs}
-        self.error = {k: 0 for k in self.attrs}
-
-    def __len__(self):
-        return len(self.image_list)
-
-    def _evaluate_one(self, image: Image):
-        def _get_predict(r, key):
-            if isinstance(r[key], dict):
-                return r[key]['text']
-            else:
-                return r[key]
-
-        r = send_request(image, self.config_str, self.image_type)
-        err_str = ''
-        if r['status'] == '000':
-            res = r['result']
-            for key in self.attrs:
-                print('attr: ', key)
-                if key in res:
-                    gt = image.get_gt_result(key)
-                    predict = _get_predict(res, key)
-                    print(f'gt: {gt}, predict: {predict}')
-                    if predict == gt:
-                        self.correct[key] += 1
-                    else:
-                        image.cate = False
-                        self.error[key] += 1
-                        err_str += f'正确:{gt}<br>返回:{predict}<br>'
-            if image.cate:
-                image.ocr_result = r['result']
-            else:
-                image.ocr_result = err_str
-        else:
-            image.ocr_result = r['msg']
-            image.cate = False
-            for key in self.attrs:
-                self.error[key] += 1
-
-    def __call__(self):
-        for image in self.image_list:
-            yield image
-
-    def evaluate(self):
-        for image in tqdm(self.image_list):
-            self._evaluate_one(image)
-
-    def accuracy(self):
-        return sum(list(self.correct.values())) / sum(list(self.correct.values()) + list(self.error.values()))
-
-    def attrs_accuracy(self):
-        return {k: self.correct[k] / (self.correct[k] + self.error[k]) for k in self.attrs}
-
-
-if __name__ == '__main__':
-    dataset = Dataset(IMAGE_PATH, IMAGE_TYPE, CONFIG_STR, True)
-    print(len(dataset))
-    for d in dataset():
-        print(d)
-
-    dataset.evaluate()
-    print(dataset.accuracy())

+ 0 - 27
markdown/use_new.py

@@ -1,27 +0,0 @@
-from pathlib import Path
-
-from markdown.new import MD, Image, Dataset
-
-# config
-image_path = Path('images/test/image')
-image_type = None
-image_rotate = True
-ocr_address = 'local'  # 'local' 'test' 'sb'
-ocr_name = 'cet'  # 'cet' 'idcard' 'bankcard' 'regbook' 'schoolcert'
-
-if __name__ == '__main__':
-    markdown = MD('英语等级证书')
-
-    dataset = Dataset(image_path, image_type, ocr_name, ocr_address, image_rotate)
-    print(len(dataset))
-    for d in dataset():
-        print(d)
-
-    dataset.evaluate()
-    print(dataset.accuracy)
-
-    markdown.write_total_accuracy(dataset)
-    markdown.write_table_accuracy(dataset)
-    markdown.write_table_result(dataset)
-
-    markdown.f.create_md_file()

+ 2 - 3
server.py

@@ -62,9 +62,8 @@ print(f'use gpu: {use_gpu}')
 #                 warmup=True)
 #
 ocr = PaddleOCR(use_angle_cls=True,
-                rec_model_dir='./server_model/ch_ppocr_server_v2.0_rec_infer',
-                det_model_dir='./server_model/ch_ppocr_server_v2.0_det_infer',
-                cls_model_dir='./idcard_cls_infer',
+                rec_model_dir='./server_model/ch_ppocr_server_v2.0_rec_infer/',
+                det_model_dir='./server_model/ch_ppocr_server_v2.0_det_infer/',
                 ocr_version='PP-OCRv2',
                 rec_algorithm='CRNN',
                 use_gpu=use_gpu,