
update detect model & rec model

Raychar 2 years ago
parent commit 5b1a2188fc

+ 2 - 1
.gitignore

@@ -3,4 +3,5 @@
 .DS_Store
 __pycache__
 *.pyc
-output/
+output/
+/images/test

+ 3 - 0
Dockerfile

@@ -97,6 +97,9 @@ environment=PYTHONUNBUFFERED=1\n\
 ARG VERSION
 ENV USE_CUDA $VERSION
 Add . /workspace
+RUN cp predict_det.py /opt/conda/envs/py38/lib/python3.8/site-packages/paddleocr/tools/infer/predict_det.py
+RUN cp utility.py /opt/conda/envs/py38/lib/python3.8/site-packages/paddleocr/tools/infer/utility.py
+
 EXPOSE 8080
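Note: this Dockerfile and cpu.Dockerfile below hot-patch the installed PaddleOCR by overwriting predict_det.py and utility.py in site-packages with the repo copies vendored later in this commit. A minimal sketch to confirm the override inside the container, assuming the same py38 env path (this check script is hypothetical, not part of the commit):

    import filecmp

    SITE = "/opt/conda/envs/py38/lib/python3.8/site-packages/paddleocr/tools/infer/"
    for name in ("predict_det.py", "utility.py"):
        # shallow=False compares file contents, not just os.stat() signatures
        assert filecmp.cmp("/workspace/" + name, SITE + name, shallow=False), f"{name} not patched"
    print("patched files match the repo copies")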
 
 
 
 

BIN
bank_det_infer/inference.pdiparams


BIN
bank_det_infer/inference.pdiparams.info


BIN
bank_det_infer/inference.pdmodel


BIN
bank_rec_infer/inference.pdiparams


BIN
bank_rec_infer/inference.pdiparams.info


BIN
bank_rec_infer/inference.pdmodel


+ 120 - 0
convert_markdown.py

@@ -0,0 +1,120 @@
+import operator
+from pathlib import Path
+
+import numpy as np
+from mdutils.mdutils import MdUtils
+import cv2
+import requests
+import json
+import time
+import base64
+from itertools import chain
+
+url = 'http://192.168.199.249:28811'
+# url = "http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm/xxw/schoolcert"
+root = Path(__file__).parent
+mdFile = MdUtils(file_name=time.strftime("%m-%d", time.localtime()) + "_bankcard")
+
+print(root)
+
+
+def send_request(img_path, image_type=0, rotate=None):
+    # sourcery skip: use-fstring-for-concatenation
+    # or_img
+    # create the directory that holds rotated copies of the images
+    ro_dir = img_path.parent.parent / ".ro_dir"
+    if not ro_dir.exists(): ro_dir.mkdir()
+
+    # rotate the image and save it to a file
+    #   - read the file
+    img = cv2.imread(str(img_path))
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    #   - rotate and save
+    # if rotate is not None:
+    #     img = cv2.rotate(img, rotate)
+    #
+    #     angle = "_" + str(rotate + 1)
+    #     img_path = ro_dir / (img_path.stem + angle + ".jpg")
+    #     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    #     cv2.imwrite(str(img_path), img)
+
+    with img_path.open('rb') as f:
+        img_str: str = base64.encodebytes(f.read()).decode('utf-8')
+
+    # request
+    # headers = {
+    #     'Authorization': "Bearer 9679c2b3-b90b-4029-a3c7-f347b4d242f7",
+    #     'content-type': "application/json"
+    # }
+    # r = requests.post(url, json={"image": img_str, "image_type": image_type}, headers=headers)
+
+    r = requests.post(url + '/ocr_system/bankcard', json={"image": img_str, "image_type": image_type})
+
+    resp = r.json()  # parse the body once instead of twice
+    print(resp)
+    return resp, img_path
+
+
+def _parse_result(r):
+    if r['status'] == '000':
+        r = r['result']
+        del r['confidence']
+        return {k: v['text'] if isinstance(v, dict) else v for k, v in r.items()}
+    elif r['status'] == '101':
+        return r['msg']
+
+
+
+def compare_dic(dic, dic1, mdimg_path):
+    global true_num
+    image = mdFile.new_inline_image(text='', path=mdimg_path)
+    if operator.eq(dic, dic1):
+        true_list.extend([image, dic1])
+    elif isinstance(dic1, dict):
+        err_str = ""
+        for key in dic:
+            if dic[key] != dic1[key]:
+                err_str = f"{err_str}expected: {dic[key]}<br>returned: {dic1[key]}<br>"
+                true_num = true_num - 1
+        false_list.extend([image, err_str])
+    elif isinstance(dic1, str):
+        false_list.extend([image, dic1])
+
+
+if __name__ == '__main__':
+    true_list = ["Image", "Result"]
+    false_list = ["Image", "Result"]
+
+    # img_paths = chain(*[Path('./images/test/7.28/0').rglob('*.jpg')])
+    img_paths = list(chain(*[Path('./images/test/7.28/0').rglob('*.jpg')]))
+    true_num = all_num = 2 * len(img_paths)
+    # for each image in the folder
+    for img_path in img_paths:
+        print(img_path)
+
+        # open the image's ground-truth JSON file
+        imgj_path_d = img_path.parent
+        imgj_path_n = f'{img_path.stem}.json'
+        imgj_path = imgj_path_d / imgj_path_n
+        with imgj_path.open('r') as json_f:
+            dic_j = json.load(json_f)
+
+        # all four orientations
+        # for (oor, cv2_orr) in {0: None, 1: 0, 2: 1, 3: 2}.items():
+        #     r, path = send_request(root / img_path, 0, cv2_orr)
+        #     dic_pic = _parse_result(r)
+        #     dic_j['orientation'] = oor
+        #     compare_dic(dic_j, dic_pic, str(path))
+        r, path = send_request(root / img_path, 0, 0)
+        dic_pic = _parse_result(r)
+        compare_dic(dic_j, dic_pic, str(path))
+
+    mdFile.new_header(level=1, title='Test Accuracy')
+    mdFile.new_paragraph("{:.2f}%".format(true_num / all_num * 100))
+    mdFile.new_header(level=1, title='True')
+    mdFile.new_table(columns=2, rows=len(true_list) // 2, text=true_list, text_align='center')
+
+    mdFile.new_header(level=1, title='False')
+    mdFile.new_table(columns=2, rows=len(false_list) // 2, text=false_list, text_align='center')
+
+    mdFile.create_md_file()
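For reference, _parse_result above handles exactly two response shapes from the service; a small sketch (the field name "number" is illustrative, the real keys come from the server's JSON):

    ok = {"status": "000",
          "result": {"number": {"text": "6212260000000000000"},
                     "confidence": 0.99}}
    err = {"status": "101", "msg": "cannot recognize"}
    assert _parse_result(ok) == {"number": "6212260000000000000"}
    assert _parse_result(err) == "cannot recognize"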

+ 35 - 7
core/anchor.py

@@ -71,7 +71,7 @@ class OcrAnchor(object):
         return (l + r) / 2, (t + b) / 2
 
     # is this text the anchor?
-    def is_anchor(self, txt, box) -> bool:
+    def is_anchor(self, txt, box, conf) -> bool:
         pass
 
     # find the anchor -> anchor coordinates
@@ -83,10 +83,22 @@ class OcrAnchor(object):
             for r in row:
                 txt = r.txt.replace('-', '').replace(' ', '')
                 box = r.box
-                if self.is_anchor(txt, box):
+                conf = r.conf
+                flag = self.is_anchor(txt, box, conf)
+                if flag:
                     l, t = np.min(box, 0)
                     r, b = np.max(box, 0)
                     return True, (l + r) / 2, (t + b) / 2
+        #         if flag and (len(re.findall('\d{10,20}', txt)) > 0 and conf > 0.95):
+        #             l, t = np.min(box, 0)
+        #             r, b = np.max(box, 0)
+        #             return True, (l + r) / 2, (t + b) / 2
+        #         elif flag:
+        #             l, t = np.min(box, 0)
+        #             r, b = np.max(box, 0)
+        # if l:
+        #     return True, (l + r) / 2, (t + b) / 2
+        # else:
         return False, 0., 0.
 
     # locate the anchor -> angle
@@ -97,8 +109,6 @@ class OcrAnchor(object):
         # if the anchor number cannot be recognized
         if not found: raise Exception(f'anchor {self.name} not found')
         cx, cy = self.get_rec_area(res)
-        # print(f'id_cx: {id_cx}, id_cy: {id_cy}')
-        # print(f'cx: {cx}, cy: {cy}')
         pre = None
         for d in self.direction:
             f = self.direction_funcs.get(d, None)
@@ -115,10 +125,28 @@ class BankCardAnchor(OcrAnchor):
     def __init__(self, name: str, d: List[Direction]):
         super(BankCardAnchor, self).__init__(name, d)
 
-    def is_anchor(self, txt, box) -> bool:
-        txts = re.findall('\d{10,20}', txt)
-        if len(txts) > 0:
+    def is_anchor(self, txt, box, conf) -> bool:
+        # # I changed this part; the digit-run length may need tweaking, verify when testing
+        # txts = re.findall(r'\d{5,20}', txt)
+        # # print(txts)
+        # if conf > 0.95 and len(txts) > 0:
+        #     # print("card number recognized here:", txts)
+        #     return True
+
+        # The logic here is a bit long; ideally, a high-confidence txt sits
+        # near the card number, usually right below it
+        if len(re.findall(r'\d{16,20}', txt)) > 0 and conf > 0.95:  # full card number found
+            return True
+        elif len(re.findall(r'\d{10,16}', txt)) > 0 and conf > 0.95:  # a bit more than half of the number
+            return True
+        elif len(re.findall(r'\d{6,10}', txt)) > 0 and conf > 0.95:  # only a small part of the number
             return True
+        elif len(re.findall(r'\d{4,6}', txt)) > 0 and conf > 0.95:  # just a few digits
+            return True
+        elif conf > 0.95:  # maybe only a single digit was found, but with very high confidence
+            return True
+        # elif conf >= 0.9:
+        #     return True
+
         return False
 
     def locate_anchor(self, res, is_horizontal) -> int:
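Every branch of the new cascade requires conf > 0.95, and the final bare elif accepts any such text, so as written the predicate reduces to the confidence test alone; the digit-run tiers only document intent. An equivalent one-liner, shown for comparison rather than as the author's method:

    def is_anchor(self, txt, box, conf) -> bool:
        # behaviorally identical to the cascade above
        return conf > 0.95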

+ 6 - 1
core/direction.py

@@ -24,7 +24,6 @@ class AngleDetector(object):
 
 
     def detect_angle(self, img):
         ocr_anchor = BankCardAnchor('bank card number', [Direction.BOTTOM])
-
         result = self.ocr.ocr(img, cls=True)
 
         try:
@@ -32,6 +31,7 @@ class AngleDetector(object):
             return angle, result
 
         except Exception as e:
+            print("exception raised in direction.py ......")
             print(e)
             # if the first pass fails, rotate 90 degrees and try again
             img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
@@ -39,3 +39,8 @@ class AngleDetector(object):
             angle = detect_angle(result, ocr_anchor)
             # after rotating 90 degrees the angle must be recomputed
             return (angle - 1 + 4) % 4, result
+
+    def origin_detect(self, img):
+        # Fallback, generally used when our own detection model returns result=[]:
+        # run the stock model once more; if this also finds nothing, the image
+        # truly cannot be detected
+        result = self.ocr.ocr(img)
+        return result

+ 5 - 3
core/line_parser.py

@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 import numpy as np
 
+
 # result object
 @dataclass
 class OcrResult(object):
@@ -45,6 +46,7 @@ class OcrResult(object):
         dist = abs(self.center[y_idx] - b.center[y_idx])
         return dist < eps
 
+
 # row parser
 class LineParser(object):
     def __init__(self, ocr_raw_result):
@@ -78,8 +80,8 @@ class LineParser(object):
         length = len(self.ocr_res)
 
         # if there is at most one field, raise an exception
-        if length <= 1:
-            raise Exception('cannot recognize')
+        # if length <= 1:
+        #     raise Exception('cannot recognize')
 
         # iterate over the array and process each element
         for i in range(length):
@@ -100,4 +102,4 @@ class LineParser(object):
                     res_row.add(res_j)
             res.append(res_row)
         idx = self.is_horizontal + 0
-        return sorted([sorted(list(r), key=lambda x: x.lt[1-idx]) for r in res], key=lambda x: x[0].lt[idx])
+        return sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in res], key=lambda x: x[0].lt[idx])
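A toy check of the two-level sort on the last line, assuming lt is the [x, y] top-left of each box and a horizontal card (idx == 1): rows are ordered top-to-bottom by y, and boxes within a row left-to-right by x.

    from types import SimpleNamespace as Box  # stand-in for OcrResult

    rows = [[Box(lt=[30, 50]), Box(lt=[10, 52])], [Box(lt=[12, 20])]]
    idx = 1  # is_horizontal + 0
    out = sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in rows],
                 key=lambda x: x[0].lt[idx])
    # out: the y=20 row comes first; in the lower row, x=10 precedes x=30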

+ 66 - 12
core/ocr.py

@@ -1,12 +1,14 @@
 from dataclasses import dataclass
 
 import numpy as np
-from paddleocr import PaddleOCR
+from paddleocr import PaddleOCR, draw_ocr
 
 from core.direction import *
 from core.line_parser import LineParser
 from core.parser import *
 
+from PIL import Image
+
 
 @dataclass
 class BankOcr:
@@ -14,19 +16,29 @@ class BankOcr:
     angle_detector: AngleDetector
 
     def predict(self, image: np.ndarray):
-        image, angle, result = self._pre_process(image)
+        image, angle, ori_result = self._pre_process(image)
         print(f'---------- detected angle: {angle} --------')
-        if angle != 0:
-            # if the angle is not 0, re-run recognition
-            _, _, result = self._ocr(image)
+        # Use our own trained detection and recognition models here; by this
+        # point, ideally, every bank card is already at angle 0 (upright)
+        _, _, result = self._ocr(image)
+
+        # self.imshow(image, result)  # save the visualized detection result
         return self._post_process(result, angle)
 
+    def imshow(self, image, result):
+        img = Image.fromarray(image).convert("RGB")
+        boxes = [line[0] for line in result]
+        txts = [line[1][0] for line in result]
+        scores = [line[1][1] for line in result]
+        im_show = draw_ocr(img, boxes, txts, scores, font_path="./simfang.ttf")
+        im_show = Image.fromarray(im_show)
+        im_show.save("./img.jpg")
+
     def _pre_process(self, image: np.ndarray):
         angle, result = self.angle_detector.detect_angle(image)
 
         if angle == 1:
             image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
-        print(angle)  # counterclockwise
+        # print("detected angle:", angle)  # counterclockwise
         if angle == 2:
             image = cv2.rotate(image, cv2.ROTATE_180)
         if angle == 3:
@@ -35,20 +47,62 @@ class BankOcr:
         return image, angle, result
 
     def _ocr(self, image):
-        # get the model's detection result
-        result = self.ocr.ocr(image, cls=True)
+        # Get the detection result; the photo is upright by now, so the
+        # direction classifier is unnecessary
+        result = self.ocr.ocr(image, cls=False)
         print("------------------")
-        print(result)
+        print("result:", result)
+        print("------------------")
+
+        # result == []: run the stock detector once more
         if not result:
-            raise Exception('cannot recognize')
-        confs = [line[1][1] for line in result]
+            print("falling back to the stock detection model ............")
+            result = self.angle_detector.origin_detect(image)
+            # still empty: nothing can be detected
+            if not result:
+                raise Exception('unrecognizable after two detection passes!!!')
+
+            confs = [line[1][1] for line in result]
+            txts = [line[1][0] for line in result]
+            return txts, confs, result
 
+        # result != []: apply some sanity rules
+        if result:
+            confs = [line[1][1] for line in result]
+            print("confs from our own detection model:", confs)
+            if len(result) == 2 and all(map(lambda x: x > 0.975, confs)):
+                l_box, r_box = [], []
+                l_box.extend(result[0][0])
+                r_box.extend(result[1][0])
+
+                l_max, _ = np.max(l_box, 0)
+                r_min, _ = np.min(r_box, 0)
+
+                if l_max > r_min:
+                    print("our own detection model is unreliable here")
+                    result = self.angle_detector.origin_detect(image)
+            else:
+                # usually len(result) == 1 here
+                flag = 0
+                # note: the original `if map(...)` was always truthy; any() is the intended check
+                if any(conf >= 0.975 for conf in confs):
+                    flag = 1
+                if flag == 0:
+                    print("falling back to the stock detection model ............")
+                    result = self.angle_detector.origin_detect(image)
+
+        # still empty after that: nothing can be detected
+        if not result:
+            raise Exception('unrecognizable after two detection passes!!!')
+
+        confs = [line[1][1] for line in result]
         # collect the recognized texts into a list
         txts = [line[1][0] for line in result]
         return txts, confs, result
 
     def _post_process(self, raw_result, angle: int):
-
         # feed the test image to the OCR; results land in self.raw_results
         line_parser = LineParser(raw_result)
         line_results = line_parser.parse()
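A hedged reading of the len(result) == 2 branch above: result[i][0] is a four-point box, so l_max is the right edge of the first detection and r_min the left edge of the second; l_max > r_min means the boxes overlap horizontally, which is taken as a sign the custom detector misfired. With made-up coordinates:

    import numpy as np

    l_box = [[10, 5], [120, 5], [120, 40], [10, 40]]
    r_box = [[100, 5], [200, 5], [200, 40], [100, 40]]
    l_max, _ = np.max(l_box, 0)  # 120, right edge of the left box
    r_min, _ = np.min(r_box, 0)  # 100, left edge of the right box
    assert l_max > r_min         # overlap -> fall back to origin_detect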

+ 29 - 40
core/parser.py

@@ -19,49 +19,39 @@ class RecItem:
 
 
 
 
 def find_card_row(line_results):
-    res = None
+    print('~~~~line results~~~~~')
     for row in line_results:
-        row = sorted(row, key=lambda x: x.lt[0])
-        txt = [r.txt.replace(' ', '').replace('.', '') for r in row]
+        print('++++')
+        print(row)
+    print('~~~~line results~~~~~')
+
+    new_lines = []
+    for row in line_results:
+        new_line = []
+        for r in row:
+            if r.conf > 0.93:
+                new_line.append(r)
+        if new_line:
+            new_lines.append(new_line)
+
+    print('~~~~new line results~~~~~')
+    for row in new_lines:
+        print('++++')
+        print(row)
+    print('~~~~new line results~~~~~')
+    line_results = new_lines
+
+    for row in line_results:
+        txt = [r.txt.replace(' ', '') for r in row]
         conf = np.mean([r.conf for r in row])
-        lts = [r.lt for r in row]
-        rbs = [r.rb for r in row]
-        lt = np.min(np.stack(lts), 0)
-        rb = np.max(np.stack(rbs), 0)
         txt = ''.join(txt)
         res = re.findall('\d{15,20}', txt)
-        if res:
-            return row, res[0], conf, lt.astype(np.int).tolist(), rb.astype(np.int).tolist()
-    if not res:
-        res_lt, res_rb = None, None
-        row_res = 0
-        max_w, max_h = 0, 0
-        conf_res = 0.
-        for row in line_results:
-            txt = ''.join([r.txt.replace(' ', '').replace('.', '') for r in row])
-            conf = np.mean([r.conf for r in row])
-            print(txt)
-            if not txt.isascii(): continue
-            lts = [r.lt for r in row]
-            rbs = [r.rb for r in row]
-            lt = np.min(np.stack(lts), 0)
-            rb = np.max(np.stack(rbs), 0)
-            print(lt, rb, '-------')
-            w, h = (rb - lt).astype(np.int).tolist()
-            print(w, h, '-------')
-            if w > max_w:
-                row_res = row
-                max_w, max_h = w, h
-                res_lt, res_rb = lt, rb
-                conf_res = conf
-        res = re.findall('\d{15,20}', txt)
-        if res:
-            return row_res, txt, conf_res, res_lt.astype(np.int).tolist(), res_rb.astype(np.int).tolist()
-        else:
-            print('cannot recognize', txt)
-
+        print(f'res: {res}, conf: {conf}')
+        if res and conf > 0.95:
+            return row, res[0], conf
     raise Exception('cannot recognize')
 
+
 def handle_wrong_digits(s):
     s = s.replace(' ', '')
     s = s.replace('-', '')
@@ -74,6 +64,7 @@ def handle_wrong_digits(s):
         s = '6' + s[1:]
     return s
 
+
 class Parser(object):
     def __init__(self, line_results: List[List[OcrResult]]):
         self.line_results = line_results
@@ -82,13 +73,11 @@ class Parser(object):
 
 
     def bank_no(self):
         # card number
-        row, txt, conf, lt, rb = find_card_row(self.line_results)
+        row, txt, conf = find_card_row(self.line_results)
         print(f'=== txt: {txt}, res: {row}======')
         txt = handle_wrong_digits(txt)
         self.res['number'] = RecItem(txt, conf)
 
-
-
     def parse(self):
         self.bank_no()
         return self.res
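In short, the rewritten find_card_row keeps boxes with conf > 0.93, joins each surviving row's text, and returns the first row whose joined text contains a 15-20 digit run with mean confidence above 0.95. A quick check of the digit-run match with an illustrative number:

    import re

    txt = '6212 2600 0000 0000 000'.replace(' ', '')
    print(re.findall(r'\d{15,20}', txt))  # ['6212260000000000000']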

+ 3 - 0
cpu.Dockerfile

@@ -95,6 +95,9 @@ stdout_logfile_maxbytes=0\n\
 ARG VERSION
 ENV USE_CUDA $VERSION
 Add . /workspace
+RUN cp predict_det.py /opt/conda/envs/py38/lib/python3.8/site-packages/paddleocr/tools/infer/predict_det.py
+RUN cp utility.py /opt/conda/envs/py38/lib/python3.8/site-packages/paddleocr/tools/infer/utility.py
+
 EXPOSE 8080
 
 
 
 

+ 7 - 7
docker-compose.yml

@@ -10,12 +10,12 @@ services:
     tty: true
     working_dir: /workspace
     ports:
-      - '18081:8080'
-      - '18223:22'
+      - '28811:8080'
+      - '222:22'
     volumes:
       - ./:/workspace
-    deploy:
-      resources:
-        reservations:
-          devices:
-            - capabilities: [gpu]
+#    deploy:
+#      resources:
+#        reservations:
+#          devices:
+#            - capabilities: [gpu]
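The new host port matches the url already used by convert_markdown.py. A smoke test against the remapped port, reusing that script's request shape (the IP and sample path come from this repo's test setup and may differ on your machine):

    import base64
    import requests

    with open('./images/test/7.28/0/1.jpg', 'rb') as f:  # hypothetical sample image
        img_str = base64.encodebytes(f.read()).decode('utf-8')
    r = requests.post('http://192.168.199.249:28811/ocr_system/bankcard',
                      json={"image": img_str, "image_type": 0})
    print(r.json())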

+ 95 - 0
en_dict.txt

@@ -0,0 +1,95 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+ 

+ 1 - 1
ppocr_keys_bank.txt

@@ -7,4 +7,4 @@
 6
 7
 8
-9
+9

+ 303 - 0
predict_det.py

@@ -0,0 +1,303 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import time
+import sys
+
+import tools.infer.utility as utility
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+import json
+logger = get_logger()
+
+
+class TextDetector(object):
+    def __init__(self, args):
+        self.args = args
+        self.det_algorithm = args.det_algorithm
+        self.use_onnx = args.use_onnx
+        pre_process_list = [{
+            'DetResizeForTest': {
+                # 'limit_side_len': args.det_limit_side_len,
+                # 'limit_type': args.det_limit_type,
+                'resize_long': args.det_resize_long
+            }
+        }, {
+            'NormalizeImage': {
+                'std': [0.229, 0.224, 0.225],
+                'mean': [0.485, 0.456, 0.406],
+                'scale': '1./255.',
+                'order': 'hwc'
+            }
+        }, {
+            'ToCHWImage': None
+        }, {
+            'KeepKeys': {
+                'keep_keys': ['image', 'shape']
+            }
+        }]
+        postprocess_params = {}
+        if self.det_algorithm == "DB":
+            postprocess_params['name'] = 'DBPostProcess'
+            postprocess_params["thresh"] = args.det_db_thresh
+            postprocess_params["box_thresh"] = args.det_db_box_thresh
+            postprocess_params["max_candidates"] = 1000
+            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
+            postprocess_params["use_dilation"] = args.use_dilation
+            postprocess_params["score_mode"] = args.det_db_score_mode
+        elif self.det_algorithm == "EAST":
+            postprocess_params['name'] = 'EASTPostProcess'
+            postprocess_params["score_thresh"] = args.det_east_score_thresh
+            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
+            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
+        elif self.det_algorithm == "SAST":
+            pre_process_list[0] = {
+                'DetResizeForTest': {
+                    'resize_long': args.det_limit_side_len
+                }
+            }
+            postprocess_params['name'] = 'SASTPostProcess'
+            postprocess_params["score_thresh"] = args.det_sast_score_thresh
+            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
+            self.det_sast_polygon = args.det_sast_polygon
+            if self.det_sast_polygon:
+                postprocess_params["sample_pts_num"] = 6
+                postprocess_params["expand_scale"] = 1.2
+                postprocess_params["shrink_ratio_of_width"] = 0.2
+            else:
+                postprocess_params["sample_pts_num"] = 2
+                postprocess_params["expand_scale"] = 1.0
+                postprocess_params["shrink_ratio_of_width"] = 0.3
+        elif self.det_algorithm == "PSE":
+            postprocess_params['name'] = 'PSEPostProcess'
+            postprocess_params["thresh"] = args.det_pse_thresh
+            postprocess_params["box_thresh"] = args.det_pse_box_thresh
+            postprocess_params["min_area"] = args.det_pse_min_area
+            postprocess_params["box_type"] = args.det_pse_box_type
+            postprocess_params["scale"] = args.det_pse_scale
+            self.det_pse_box_type = args.det_pse_box_type
+        elif self.det_algorithm == "FCE":
+            pre_process_list[0] = {
+                'DetResizeForTest': {
+                    'rescale_img': [1080, 736]
+                }
+            }
+            postprocess_params['name'] = 'FCEPostProcess'
+            postprocess_params["scales"] = args.scales
+            postprocess_params["alpha"] = args.alpha
+            postprocess_params["beta"] = args.beta
+            postprocess_params["fourier_degree"] = args.fourier_degree
+            postprocess_params["box_type"] = args.det_fce_box_type
+        else:
+            logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
+            sys.exit(0)
+
+        self.preprocess_op = create_operators(pre_process_list)
+        self.postprocess_op = build_post_process(postprocess_params)
+        self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
+            args, 'det', logger)
+
+        if self.use_onnx:
+            img_h, img_w = self.input_tensor.shape[2:]
+            if img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
+                pre_process_list[0] = {
+                    'DetResizeForTest': {
+                        'image_shape': [img_h, img_w]
+                    }
+                }
+        self.preprocess_op = create_operators(pre_process_list)
+
+        if args.benchmark:
+            import auto_log
+            pid = os.getpid()
+            gpu_id = utility.get_infer_gpuid()
+            self.autolog = auto_log.AutoLogger(
+                model_name="det",
+                model_precision=args.precision,
+                batch_size=1,
+                data_shape="dynamic",
+                save_path=None,
+                inference_config=self.config,
+                pids=pid,
+                process_name=None,
+                gpu_ids=gpu_id if args.use_gpu else None,
+                time_keys=[
+                    'preprocess_time', 'inference_time', 'postprocess_time'
+                ],
+                warmup=2,
+                logger=logger)
+
+    def order_points_clockwise(self, pts):
+        rect = np.zeros((4, 2), dtype="float32")
+        s = pts.sum(axis=1)
+        rect[0] = pts[np.argmin(s)]
+        rect[2] = pts[np.argmax(s)]
+        diff = np.diff(pts, axis=1)
+        rect[1] = pts[np.argmin(diff)]
+        rect[3] = pts[np.argmax(diff)]
+        return rect
+
+    def clip_det_res(self, points, img_height, img_width):
+        for pno in range(points.shape[0]):
+            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+        return points
+
+    def filter_tag_det_res(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            box = self.order_points_clockwise(box)
+            box = self.clip_det_res(box, img_height, img_width)
+            rect_width = int(np.linalg.norm(box[0] - box[1]))
+            rect_height = int(np.linalg.norm(box[0] - box[3]))
+            if rect_width <= 3 or rect_height <= 3:
+                continue
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            box = self.clip_det_res(box, img_height, img_width)
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def __call__(self, img):
+        ori_im = img.copy()
+        data = {'image': img}
+
+        st = time.time()
+
+        if self.args.benchmark:
+            self.autolog.times.start()
+
+        data = transform(data, self.preprocess_op)
+        img, shape_list = data
+        if img is None:
+            return None, 0
+        img = np.expand_dims(img, axis=0)
+        shape_list = np.expand_dims(shape_list, axis=0)
+        img = img.copy()
+
+        if self.args.benchmark:
+            self.autolog.times.stamp()
+        if self.use_onnx:
+            input_dict = {}
+            input_dict[self.input_tensor.name] = img
+            outputs = self.predictor.run(self.output_tensors, input_dict)
+        else:
+            self.input_tensor.copy_from_cpu(img)
+            self.predictor.run()
+            outputs = []
+            for output_tensor in self.output_tensors:
+                output = output_tensor.copy_to_cpu()
+                outputs.append(output)
+            if self.args.benchmark:
+                self.autolog.times.stamp()
+
+        preds = {}
+        if self.det_algorithm == "EAST":
+            preds['f_geo'] = outputs[0]
+            preds['f_score'] = outputs[1]
+        elif self.det_algorithm == 'SAST':
+            preds['f_border'] = outputs[0]
+            preds['f_score'] = outputs[1]
+            preds['f_tco'] = outputs[2]
+            preds['f_tvo'] = outputs[3]
+        elif self.det_algorithm in ['DB', 'PSE']:
+            preds['maps'] = outputs[0]
+        elif self.det_algorithm == 'FCE':
+            for i, output in enumerate(outputs):
+                preds['level_{}'.format(i)] = output
+        else:
+            raise NotImplementedError
+
+        #self.predictor.try_shrink_memory()
+        post_result = self.postprocess_op(preds, shape_list)
+        dt_boxes = post_result[0]['points']
+        if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
+                self.det_algorithm in ["PSE", "FCE"] and
+                self.postprocess_op.box_type == 'poly'):
+            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
+        else:
+            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
+
+        if self.args.benchmark:
+            self.autolog.times.end(stamp=True)
+        et = time.time()
+        return dt_boxes, et - st
+
+
+if __name__ == "__main__":
+    args = utility.parse_args()
+    image_file_list = get_image_file_list(args.image_dir)
+    text_detector = TextDetector(args)
+    count = 0
+    total_time = 0
+    draw_img_save = "./inference_results"
+
+    if args.warmup:
+        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
+        for i in range(2):
+            res = text_detector(img)
+
+    if not os.path.exists(draw_img_save):
+        os.makedirs(draw_img_save)
+    save_results = []
+    for image_file in image_file_list:
+        img, flag = check_and_read_gif(image_file)
+        if not flag:
+            img = cv2.imread(image_file)
+        if img is None:
+            logger.info("error in loading image:{}".format(image_file))
+            continue
+        st = time.time()
+        dt_boxes, _ = text_detector(img)
+        elapse = time.time() - st
+        if count > 0:
+            total_time += elapse
+        count += 1
+        save_pred = os.path.basename(image_file) + "\t" + str(
+            json.dumps([x.tolist() for x in dt_boxes])) + "\n"
+        save_results.append(save_pred)
+        logger.info(save_pred)
+        logger.info("The predict time of {}: {}".format(image_file, elapse))
+        src_im = utility.draw_text_det_res(dt_boxes, image_file)
+        img_name_pure = os.path.split(image_file)[-1]
+        img_path = os.path.join(draw_img_save,
+                                "det_res_{}".format(img_name_pure))
+        cv2.imwrite(img_path, src_im)
+        logger.info("The visualized image saved in {}".format(img_path))
+
+    with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
+        f.writelines(save_results)
+        f.close()
+    if args.benchmark:
+        text_detector.autolog.report()
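Relative to upstream PaddleOCR, the visible change in this vendored copy is the detector preprocessing: DetResizeForTest now uses a fixed resize_long instead of limit_side_len/limit_type (the matching --det_resize_long flag is added in utility.py below). A sketch of the difference, assuming upstream's limit_type='max' semantics:

    # resize_long: always scale so the longer side equals the target;
    # limit_type='max': only shrink when the longer side exceeds the target.
    h, w, target = 400, 720, 960
    ratio_resize_long = target / max(h, w)          # ~1.33, small images get upscaled
    ratio_limit_max = min(1.0, target / max(h, w))  # 1.0, small images left alone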

+ 0 - 1
run.py

@@ -7,6 +7,5 @@ if __name__ == '__main__':
     parser.add_argument('--port', default=8080)
     opt = parser.parse_args()
 
-
     app_str = 'server:app'  # make the app string equal to whatever the name of this file is
     uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True, log_level='debug', workers=1)

+ 27 - 9
server.py

@@ -34,21 +34,39 @@ print(f'use gpu: {use_gpu}')
 #                 rec_model_dir="./bank_rec_infer/",
 #                 det_model_dir="./bank_det_infer/",
 #                 cls_model_dir="./bank_cls_infer/",
-#                 rec_algorithm='CRNN',
-#                 # rec_image_shape='3, 32, 320',
-#                 ocr_version='PP-OCRv2',
-#                 rec_char_dict_path="./ppocr_keys_v1.txt",
+#                 # rec_algorithm='SVTR_LCNet',
+#                 rec_image_shape='3, 48, 320',
+#                 # ocr_version='PP-OCRv2',
+#                 rec_char_dict_path="./ppocr_keys_bank.txt",
 #                 use_gpu=use_gpu,
 #                 save_crop_res=True,
 #                 warmup=True)
 
 
+
 ocr = PaddleOCR(use_angle_cls=True,
                 use_gpu=use_gpu,
                 det_db_unclip_ratio=2.5,
-                det_db_thresh=0.1,
-                det_db_box_thresh=0.4,
-                # save_crop_res=True,
-                warmup=True)
+                det_db_thresh=0.3,
+                det_db_box_thresh=0.6,
+                det_model_dir="./bank_det_infer/",
+                save_crop_res=True,
+                rec_model_dir="./bank_rec_infer/",
+                rec_char_dict_path="./ppocr_keys_bank.txt",
+                use_space_char=False,
+                warmup=True
+                )
+
+
+origin_ocr = PaddleOCR(use_angle_cls=True,
+                       use_gpu=use_gpu,
+                       det_db_unclip_ratio=2.5,
+                       det_db_thresh=0.3,
+                       det_db_box_thresh=0.6,
+                       rec_model_dir="./bank_rec_infer/",
+                       rec_char_dict_path="./ppocr_keys_bank.txt",
+                       use_space_char=False,
+                       warmup=True
+                       )
 
 
 
 
 # ocr = PaddleOCR(use_angle_cls=True,
@@ -64,7 +82,7 @@ ocr = PaddleOCR(use_angle_cls=True,
 #                 warmup=True)
 
 
 
 
-ad = AngleDetector(ocr)
+ad = AngleDetector(origin_ocr)
 m = BankOcr(ocr, ad)
 
 
 
 

BIN
simfang.ttf


+ 0 - 0
testing/101_error_test.py → testing/error_101_test.py


+ 2 - 1
testing/utils.py

@@ -3,13 +3,14 @@ import base64
 import requests
 import requests
 
 
 
-url = 'http://192.168.13.54:18081'
+url = 'http://192.168.199.249:2991'
 
 
 
 
 # url = 'http://aihub-test.digitalyili.com/aiSquare/openApi/reasoning-services/rlocrxm'
 # header = {
 #     'Authorization': 'Bearer 9679c2b3-b90b-4029-a3c7-f347b4d242f7'
 # }
+
 def send_request(image_path):
     with open(image_path, 'rb') as f:
         img_str: str = base64.encodebytes(f.read()).decode('utf-8')
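Hedged usage sketch of the updated helper, assuming it returns the parsed JSON like its sibling in convert_markdown.py (the sample path is a placeholder):

    resp = send_request('./images/test/7.28/0/1.jpg')
    print(resp['status'], resp.get('result') or resp.get('msg'))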

+ 646 - 0
utility.py

@@ -0,0 +1,646 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import sys
+import platform
+import cv2
+import numpy as np
+import paddle
+from PIL import Image, ImageDraw, ImageFont
+import math
+from paddle import inference
+import time
+from ppocr.utils.logging import get_logger
+
+
+def str2bool(v):
+    return v.lower() in ("true", "t", "1")
+
+
+def init_args():
+    parser = argparse.ArgumentParser()
+    # params for prediction engine
+    parser.add_argument("--use_gpu", type=str2bool, default=True)
+    parser.add_argument("--ir_optim", type=str2bool, default=True)
+    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
+    parser.add_argument("--min_subgraph_size", type=int, default=15)
+    parser.add_argument("--precision", type=str, default="fp32")
+    parser.add_argument("--gpu_mem", type=int, default=500)
+
+    # params for text detector
+    parser.add_argument("--image_dir", type=str)
+    parser.add_argument("--det_algorithm", type=str, default='DB')
+    parser.add_argument("--det_model_dir", type=str)
+    parser.add_argument("--det_resize_long", type=float, default=960)
+    parser.add_argument("--det_limit_side_len", type=float, default=960)
+    parser.add_argument("--det_limit_type", type=str, default='max')
+
+    # DB params
+    parser.add_argument("--det_db_thresh", type=float, default=0.3)
+    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
+    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
+    parser.add_argument("--max_batch_size", type=int, default=10)
+    parser.add_argument("--use_dilation", type=str2bool, default=False)
+    parser.add_argument("--det_db_score_mode", type=str, default="fast")
+    # EAST params
+    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
+    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
+    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
+
+    # SAST params
+    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
+    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
+    parser.add_argument("--det_sast_polygon", type=str2bool, default=False)
+
+    # PSE params
+    parser.add_argument("--det_pse_thresh", type=float, default=0)
+    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
+    parser.add_argument("--det_pse_min_area", type=float, default=16)
+    parser.add_argument("--det_pse_box_type", type=str, default='quad')
+    parser.add_argument("--det_pse_scale", type=int, default=1)
+
+    # FCE params
+    parser.add_argument("--scales", type=list, default=[8, 16, 32])
+    parser.add_argument("--alpha", type=float, default=1.0)
+    parser.add_argument("--beta", type=float, default=1.0)
+    parser.add_argument("--fourier_degree", type=int, default=5)
+    parser.add_argument("--det_fce_box_type", type=str, default='poly')
+
+    # params for text recognizer
+    parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
+    parser.add_argument("--rec_model_dir", type=str)
+    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
+    parser.add_argument("--rec_batch_num", type=int, default=6)
+    parser.add_argument("--max_text_length", type=int, default=25)
+    parser.add_argument(
+        "--rec_char_dict_path",
+        type=str,
+        default="./ppocr/utils/ppocr_keys_v1.txt")
+    parser.add_argument("--use_space_char", type=str2bool, default=True)
+    parser.add_argument(
+        "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
+    parser.add_argument("--drop_score", type=float, default=0.5)
+
+    # params for e2e
+    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
+    parser.add_argument("--e2e_model_dir", type=str)
+    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
+    parser.add_argument("--e2e_limit_type", type=str, default='max')
+
+    # PGNet params
+    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
+    parser.add_argument(
+        "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
+    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
+    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')
+
+    # params for text classifier
+    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
+    parser.add_argument("--cls_model_dir", type=str)
+    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
+    parser.add_argument("--label_list", type=list, default=['0', '180'])
+    parser.add_argument("--cls_batch_num", type=int, default=6)
+    parser.add_argument("--cls_thresh", type=float, default=0.9)
+
+    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
+    parser.add_argument("--cpu_threads", type=int, default=10)
+    parser.add_argument("--use_pdserving", type=str2bool, default=False)
+    parser.add_argument("--warmup", type=str2bool, default=False)
+
+    #
+    parser.add_argument(
+        "--draw_img_save_dir", type=str, default="./inference_results")
+    parser.add_argument("--save_crop_res", type=str2bool, default=False)
+    parser.add_argument("--crop_res_save_dir", type=str, default="./output")
+
+    # multi-process
+    parser.add_argument("--use_mp", type=str2bool, default=False)
+    parser.add_argument("--total_process_num", type=int, default=1)
+    parser.add_argument("--process_id", type=int, default=0)
+
+    parser.add_argument("--benchmark", type=str2bool, default=False)
+    parser.add_argument("--save_log_path", type=str, default="./log_output/")
+
+    parser.add_argument("--show_log", type=str2bool, default=True)
+    parser.add_argument("--use_onnx", type=str2bool, default=False)
+    return parser
+
+
+def parse_args():
+    parser = init_args()
+    return parser.parse_args()
+
+
+def create_predictor(args, mode, logger):
+    if mode == "det":
+        model_dir = args.det_model_dir
+    elif mode == 'cls':
+        model_dir = args.cls_model_dir
+    elif mode == 'rec':
+        model_dir = args.rec_model_dir
+    elif mode == 'table':
+        model_dir = args.table_model_dir
+    else:
+        model_dir = args.e2e_model_dir
+
+    if model_dir is None:
+        logger.info("not find {} model file path {}".format(mode, model_dir))
+        sys.exit(0)
+    if args.use_onnx:
+        import onnxruntime as ort
+        model_file_path = model_dir
+        if not os.path.exists(model_file_path):
+            raise ValueError("not find model file path {}".format(
+                model_file_path))
+        sess = ort.InferenceSession(model_file_path)
+        return sess, sess.get_inputs()[0], None, None
+
+    else:
+        model_file_path = model_dir + "/inference.pdmodel"
+        params_file_path = model_dir + "/inference.pdiparams"
+        if not os.path.exists(model_file_path):
+            raise ValueError("not find model file path {}".format(
+                model_file_path))
+        if not os.path.exists(params_file_path):
+            raise ValueError("not find params file path {}".format(
+                params_file_path))
+
+        config = inference.Config(model_file_path, params_file_path)
+
+        if hasattr(args, 'precision'):
+            if args.precision == "fp16" and args.use_tensorrt:
+                precision = inference.PrecisionType.Half
+            elif args.precision == "int8":
+                precision = inference.PrecisionType.Int8
+            else:
+                precision = inference.PrecisionType.Float32
+        else:
+            precision = inference.PrecisionType.Float32
+
+        if args.use_gpu:
+            gpu_id = get_infer_gpuid()
+            if gpu_id is None:
+                logger.warning(
+                    "GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson."
+                )
+            config.enable_use_gpu(args.gpu_mem, 0)
+            if args.use_tensorrt:
+                config.enable_tensorrt_engine(
+                    workspace_size=1 << 30,
+                    precision_mode=precision,
+                    max_batch_size=args.max_batch_size,
+                    min_subgraph_size=args.min_subgraph_size)
+                # skip the minimum trt subgraph
+            use_dynamic_shape = True
+            if mode == "det":
+                min_input_shape = {
+                    "x": [1, 3, 50, 50],
+                    "conv2d_92.tmp_0": [1, 120, 20, 20],
+                    "conv2d_91.tmp_0": [1, 24, 10, 10],
+                    "conv2d_59.tmp_0": [1, 96, 20, 20],
+                    "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10],
+                    "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20],
+                    "conv2d_124.tmp_0": [1, 256, 20, 20],
+                    "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20],
+                    "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20],
+                    "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20],
+                    "elementwise_add_7": [1, 56, 2, 2],
+                    "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2]
+                }
+                max_input_shape = {
+                    "x": [1, 3, 1536, 1536],
+                    "conv2d_92.tmp_0": [1, 120, 400, 400],
+                    "conv2d_91.tmp_0": [1, 24, 200, 200],
+                    "conv2d_59.tmp_0": [1, 96, 400, 400],
+                    "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200],
+                    "conv2d_124.tmp_0": [1, 256, 400, 400],
+                    "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400],
+                    "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400],
+                    "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400],
+                    "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400],
+                    "elementwise_add_7": [1, 56, 400, 400],
+                    "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400]
+                }
+                opt_input_shape = {
+                    "x": [1, 3, 640, 640],
+                    "conv2d_92.tmp_0": [1, 120, 160, 160],
+                    "conv2d_91.tmp_0": [1, 24, 80, 80],
+                    "conv2d_59.tmp_0": [1, 96, 160, 160],
+                    "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80],
+                    "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160],
+                    "conv2d_124.tmp_0": [1, 256, 160, 160],
+                    "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160],
+                    "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160],
+                    "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160],
+                    "elementwise_add_7": [1, 56, 40, 40],
+                    "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40]
+                }
+                min_pact_shape = {
+                    "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20],
+                    "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20],
+                    "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20],
+                    "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20]
+                }
+                max_pact_shape = {
+                    "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400],
+                    "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400],
+                    "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400],
+                    "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400]
+                }
+                opt_pact_shape = {
+                    "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160],
+                    "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160],
+                    "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160],
+                    "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160]
+                }
+                min_input_shape.update(min_pact_shape)
+                max_input_shape.update(max_pact_shape)
+                opt_input_shape.update(opt_pact_shape)
+            elif mode == "rec":
+                if args.rec_algorithm not in ["CRNN", "SVTR_LCNet"]:
+                    use_dynamic_shape = False
+                imgH = int(args.rec_image_shape.split(',')[-2])
+                min_input_shape = {"x": [1, 3, imgH, 10]}
+                max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 2304]}
+                opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]}
+            elif mode == "cls":
+                min_input_shape = {"x": [1, 3, 48, 10]}
+                max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
+                opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
+            else:
+                use_dynamic_shape = False
+            if use_dynamic_shape:
+                config.set_trt_dynamic_shape_info(
+                    min_input_shape, max_input_shape, opt_input_shape)
+
+        else:
+            config.disable_gpu()
+            if hasattr(args, "cpu_threads"):
+                config.set_cpu_math_library_num_threads(args.cpu_threads)
+            else:
+                # default cpu threads as 10
+                config.set_cpu_math_library_num_threads(10)
+            if args.enable_mkldnn:
+                # cache 10 different shapes for mkldnn to avoid memory leak
+                config.set_mkldnn_cache_capacity(10)
+                config.enable_mkldnn()
+                if args.precision == "fp16":
+                    config.enable_mkldnn_bfloat16()
+        # enable memory optim
+        config.enable_memory_optim()
+        config.disable_glog_info()
+        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
+        config.delete_pass("matmul_transpose_reshape_fuse_pass")
+        if mode == 'table':
+            config.delete_pass("fc_fuse_pass")  # not supported for table
+        config.switch_use_feed_fetch_ops(False)
+        config.switch_ir_optim(True)
+
+        # create predictor
+        predictor = inference.create_predictor(config)
+        input_names = predictor.get_input_names()
+        for name in input_names:
+            input_tensor = predictor.get_input_handle(name)
+        output_tensors = get_output_tensors(args, mode, predictor)
+        return predictor, input_tensor, output_tensors, config
+
+
+def get_output_tensors(args, mode, predictor):
+    output_names = predictor.get_output_names()
+    output_tensors = []
+    if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
+        output_name = 'softmax_0.tmp_0'
+        if output_name in output_names:
+            return [predictor.get_output_handle(output_name)]
+        else:
+            for output_name in output_names:
+                output_tensor = predictor.get_output_handle(output_name)
+                output_tensors.append(output_tensor)
+    else:
+        for output_name in output_names:
+            output_tensor = predictor.get_output_handle(output_name)
+            output_tensors.append(output_tensor)
+    return output_tensors
+
+
+def get_infer_gpuid():
+    sysstr = platform.system()
+    if sysstr == "Windows":
+        return 0
+
+    if not paddle.fluid.core.is_compiled_with_rocm():
+        cmd = "env | grep CUDA_VISIBLE_DEVICES"
+    else:
+        cmd = "env | grep HIP_VISIBLE_DEVICES"
+    env_cuda = os.popen(cmd).readlines()
+    if len(env_cuda) == 0:
+        return 0
+    else:
+        gpu_id = env_cuda[0].strip().split("=")[1]
+        return int(gpu_id[0])
+
+
+def draw_e2e_res(dt_boxes, strs, img_path):
+    src_im = cv2.imread(img_path)
+    for box, str in zip(dt_boxes, strs):
+        box = box.astype(np.int32).reshape((-1, 1, 2))
+        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+        cv2.putText(
+            src_im,
+            str,
+            org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
+            fontFace=cv2.FONT_HERSHEY_COMPLEX,
+            fontScale=0.7,
+            color=(0, 255, 0),
+            thickness=1)
+    return src_im
+
+
+def draw_text_det_res(dt_boxes, img_path):
+    src_im = cv2.imread(img_path)
+    for box in dt_boxes:
+        box = np.array(box).astype(np.int32).reshape(-1, 2)
+        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+    return src_im
+
+
+def resize_img(img, input_size=600):
+    """
+    resize img and limit the longest side of the image to input_size
+    """
+    img = np.array(img)
+    im_shape = img.shape
+    im_size_max = np.max(im_shape[0:2])
+    im_scale = float(input_size) / float(im_size_max)
+    img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
+    return img
+
+
+def draw_ocr(image,
+             boxes,
+             txts=None,
+             scores=None,
+             drop_score=0.5,
+             font_path="./doc/fonts/simfang.ttf"):
+    """
+    Visualize the results of OCR detection and recognition
+    args:
+        image(Image|array): RGB image
+        boxes(list): boxes with shape(N, 4, 2)
+        txts(list): the texts
+        scores(list): the texts' corresponding scores
+        drop_score(float): only scores greater than drop_threshold will be visualized
+        font_path: the path of font which is used to draw text
+    return(array):
+        the visualized img
+    """
+    if scores is None:
+        scores = [1] * len(boxes)
+    box_num = len(boxes)
+    for i in range(box_num):
+        if scores is not None and (scores[i] < drop_score or
+                                   math.isnan(scores[i])):
+            continue
+        box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
+        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
+    if txts is not None:
+        img = np.array(resize_img(image, input_size=600))
+        txt_img = text_visual(
+            txts,
+            scores,
+            img_h=img.shape[0],
+            img_w=600,
+            threshold=drop_score,
+            font_path=font_path)
+        img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
+        return img
+    return image
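+
+
+# Editor's note: a hypothetical end-to-end call of draw_ocr; the image path,
+# box, text and score are placeholder data, and the default font_path
+# ("./doc/fonts/simfang.ttf") must exist for the text panel to render.
+def _example_draw_ocr():
+    image = cv2.cvtColor(cv2.imread("demo.jpg"), cv2.COLOR_BGR2RGB)
+    boxes = [[[10, 10], [200, 10], [200, 40], [10, 40]]]
+    txts = ["6217 0000 0000 0000"]
+    scores = [0.98]
+    vis = draw_ocr(image, boxes, txts, scores, drop_score=0.5)
+    cv2.imwrite("demo_vis.jpg", cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))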
+
+
+def draw_ocr_box_txt(image,
+                     boxes,
+                     txts,
+                     scores=None,
+                     drop_score=0.5,
+                     font_path="./doc/simfang.ttf"):
+    h, w = image.height, image.width
+    img_left = image.copy()
+    img_right = Image.new('RGB', (w, h), (255, 255, 255))
+
+    # seed the local random module so box colors are repeatable across runs
+    import random
+    random.seed(0)
+    draw_left = ImageDraw.Draw(img_left)
+    draw_right = ImageDraw.Draw(img_right)
+    for idx, (box, txt) in enumerate(zip(boxes, txts)):
+        if scores is not None and scores[idx] < drop_score:
+            continue
+        color = (random.randint(0, 255), random.randint(0, 255),
+                 random.randint(0, 255))
+        draw_left.polygon(box, fill=color)
+        draw_right.polygon(
+            [
+                box[0][0], box[0][1], box[1][0], box[1][1], box[2][0],
+                box[2][1], box[3][0], box[3][1]
+            ],
+            outline=color)
+        box_height = math.sqrt((box[0][0] - box[3][0])**2 +
+                               (box[0][1] - box[3][1])**2)
+        box_width = math.sqrt((box[0][0] - box[1][0])**2 +
+                              (box[0][1] - box[1][1])**2)
+        if box_height > 2 * box_width:
+            font_size = max(int(box_width * 0.9), 10)
+            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+            cur_y = box[0][1]
+            for c in txt:
+                # note: font.getsize() was removed in Pillow 10; with newer
+                # Pillow, derive the height from font.getbbox(c) instead
+                char_size = font.getsize(c)
+                draw_right.text(
+                    (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
+                cur_y += char_size[1]
+        else:
+            font_size = max(int(box_height * 0.8), 10)
+            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+            draw_right.text(
+                [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
+    img_left = Image.blend(image, img_left, 0.5)
+    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
+    img_show.paste(img_left, (0, 0, w, h))
+    img_show.paste(img_right, (w, 0, w * 2, h))
+    return np.array(img_show)
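+
+
+# Editor's note: draw_ocr_box_txt expects a PIL Image (not a numpy array)
+# plus polygon boxes as point lists; a hypothetical call might look like
+# this, with the file names and coordinates as placeholders.
+def _example_draw_ocr_box_txt():
+    image = Image.open("demo.jpg").convert("RGB")
+    boxes = [[(10, 10), (200, 10), (200, 40), (10, 40)]]
+    txts = ["demo text"]
+    vis = draw_ocr_box_txt(image, boxes, txts, scores=[0.9],
+                           font_path="./doc/simfang.ttf")
+    cv2.imwrite("demo_pair.jpg", vis[:, :, ::-1])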
+
+
+def str_count(s):
+    """
+    Estimate the display width of a string in Chinese-character units:
+    each Chinese character (or other full-width symbol) counts as one
+    unit, while ASCII letters, digits and spaces count as half a unit
+    each, rounded down over the whole string.
+    args:
+        s(string): the input string
+    return(int):
+        the estimated display width of s
+    """
+    import string
+    s_len = len(s)
+    en_dg_count = 0
+    for c in s:
+        if c in string.ascii_letters or c.isdigit() or c.isspace():
+            en_dg_count += 1
+    return s_len - math.ceil(en_dg_count / 2)
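+
+
+# Editor's note: a worked example of str_count — for "abc中文", len is 5 and
+# there are 3 ASCII letters, so the result is 5 - ceil(3/2) = 3: two units
+# for the Chinese characters plus one for the three letters.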
+
+
+def text_visual(texts,
+                scores,
+                img_h=400,
+                img_w=600,
+                threshold=0.,
+                font_path="./doc/simfang.ttf"):
+    """
+    create new blank img and draw txt on it
+    args:
+        texts(list): the text will be draw
+        scores(list|None): corresponding score of each txt
+        img_h(int): the height of blank img
+        img_w(int): the width of blank img
+        font_path: the path of font which is used to draw text
+    return(array):
+    """
+    if scores is not None:
+        assert len(texts) == len(
+            scores), "The number of txts and corresponding scores must match"
+
+    def create_blank_img():
+        # uint8 keeps 255 as white; np.int8 would overflow 255 to -1 and
+        # only render correctly by accident of PIL's type mapping
+        blank_img = np.ones(shape=[img_h, img_w], dtype=np.uint8) * 255
+        blank_img[:, img_w - 1:] = 0
+        blank_img = Image.fromarray(blank_img).convert("RGB")
+        draw_txt = ImageDraw.Draw(blank_img)
+        return blank_img, draw_txt
+
+    blank_img, draw_txt = create_blank_img()
+
+    font_size = 20
+    txt_color = (0, 0, 0)
+    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+
+    gap = font_size + 5
+    txt_img_list = []
+    count, index = 1, 0
+    for idx, txt in enumerate(texts):
+        index += 1
+        if scores[idx] < threshold or math.isnan(scores[idx]):
+            index -= 1
+            continue
+        first_line = True
+        while str_count(txt) >= img_w // font_size - 4:
+            tmp = txt
+            txt = tmp[:img_w // font_size - 4]
+            if first_line:
+                new_txt = str(index) + ': ' + txt
+                first_line = False
+            else:
+                new_txt = '    ' + txt
+            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
+            txt = tmp[img_w // font_size - 4:]
+            if count >= img_h // gap - 1:
+                txt_img_list.append(np.array(blank_img))
+                blank_img, draw_txt = create_blank_img()
+                count = 0
+            count += 1
+        if first_line:
+            new_txt = str(index) + ': ' + txt + '   ' + '%.3f' % (scores[idx])
+        else:
+            new_txt = "  " + txt + "  " + '%.3f' % (scores[idx])
+        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
+        # whether add new blank img or not
+        if count >= img_h // gap - 1 and idx + 1 < len(texts):
+            txt_img_list.append(np.array(blank_img))
+            blank_img, draw_txt = create_blank_img()
+            count = 0
+        count += 1
+    txt_img_list.append(np.array(blank_img))
+    if len(txt_img_list) == 1:
+        blank_img = np.array(txt_img_list[0])
+    else:
+        blank_img = np.concatenate(txt_img_list, axis=1)
+    return np.array(blank_img)
+
+
+def base64_to_cv2(b64str):
+    import base64
+    data = base64.b64decode(b64str.encode('utf8'))
+    # np.fromstring is deprecated for binary data; np.frombuffer is the
+    # supported equivalent
+    data = np.frombuffer(data, np.uint8)
+    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+    return data
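+
+
+# Editor's note: a round-trip sketch pairing base64_to_cv2 with
+# cv2.imencode; purely illustrative, using an in-memory image rather than
+# a real request payload.
+def _example_base64_roundtrip():
+    import base64
+    img = np.zeros((8, 8, 3), dtype=np.uint8)
+    ok, buf = cv2.imencode(".jpg", img)
+    b64 = base64.b64encode(buf.tobytes()).decode("utf8")
+    decoded = base64_to_cv2(b64)
+    assert decoded.shape == img.shape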
+
+
+def draw_boxes(image, boxes, scores=None, drop_score=0.5):
+    if scores is None:
+        scores = [1] * len(boxes)
+    for (box, score) in zip(boxes, scores):
+        if score < drop_score:
+            continue
+        box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
+        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
+    return image
+
+
+def get_rotate_crop_image(img, points):
+    """
+    Crop the quadrilateral given by `points` (4x2, float32) out of `img`
+    with a perspective transform; near-vertical crops are rotated upright.
+    An earlier axis-aligned version simply sliced the bounding rectangle
+    of the points; the perspective warp below supersedes it.
+    """
+    assert len(points) == 4, "shape of points must be 4*2"
+    img_crop_width = int(
+        max(
+            np.linalg.norm(points[0] - points[1]),
+            np.linalg.norm(points[2] - points[3])))
+    img_crop_height = int(
+        max(
+            np.linalg.norm(points[0] - points[3]),
+            np.linalg.norm(points[1] - points[2])))
+    pts_std = np.float32([[0, 0], [img_crop_width, 0],
+                          [img_crop_width, img_crop_height],
+                          [0, img_crop_height]])
+    M = cv2.getPerspectiveTransform(points, pts_std)
+    dst_img = cv2.warpPerspective(
+        img,
+        M, (img_crop_width, img_crop_height),
+        borderMode=cv2.BORDER_REPLICATE,
+        flags=cv2.INTER_CUBIC)
+    dst_img_height, dst_img_width = dst_img.shape[0:2]
+    if dst_img_height * 1.0 / dst_img_width >= 1.5:
+        dst_img = np.rot90(dst_img)
+    return dst_img
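+
+
+# Editor's note: get_rotate_crop_image expects the four corners as a 4x2
+# float32 array in top-left, top-right, bottom-right, bottom-left order;
+# the file name and coordinates below are placeholders.
+def _example_rotate_crop():
+    img = cv2.imread("demo.jpg")
+    points = np.float32([[30, 40], [220, 35], [225, 95], [32, 100]])
+    return get_rotate_crop_image(img, points)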
+
+
+def check_gpu(use_gpu):
+    # fall back to CPU when this paddle build has no CUDA support
+    if use_gpu and not paddle.is_compiled_with_cuda():
+        use_gpu = False
+    return use_gpu
+
+
+if __name__ == '__main__':
+    pass