瀏覽代碼

update server add run

zeke-chin 2 年之前
父節點
當前提交
3371f7eca4
共有 2 個文件被更改,包括 20 次插入146 次删除
  1. 11 0
      run.py
  2. 9 146
      server.py

+ 11 - 0
run.py

@@ -0,0 +1,11 @@
+if __name__ == '__main__':
+    import uvicorn
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', default='0.0.0.0')
+    parser.add_argument('--port', default=8080)
+    opt = parser.parse_args()
+
+    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)

+ 9 - 146
server.py

@@ -57,7 +57,7 @@ logger.info(f"->是否使用GPU:{use_gpu}")
 ocr = PaddleOCR(use_angle_cls=True,
                 rec_model_dir="./table_rec_infer/",
                 det_model_dir="./table_det_infer/",
-                cls_model_dir="table_cls_infer",
+                cls_model_dir="./table_cls_infer/",
                 lang="ch")
 
 
@@ -66,161 +66,21 @@ def ping():
     return "pong!"
 
 
+
+
 class ImageListInfo(BaseModel):
     images: list
     img_type: str
 
-
 @app.post("/ocr_system/paddle")
 @sxtimeit
 @web_try()
-def rotate_bound_white_bg(self, image, angle):
-    (h, w) = image.shape[:2]
-    (cX, cY) = (w // 2, h // 2)
-
-    M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0)
-    cos = np.abs(M[0, 0])
-    sin = np.abs(M[0, 1])
-
-    nW = int((h * sin) + (w * cos))
-    nH = int((h * cos) + (w * sin))
-
-    M[0, 2] += (nW / 2) - cX
-    M[1, 2] += (nH / 2) - cY
-    return cv2.warpAffine(image, M, (nW, nH), borderValue=(255, 255, 255))
-
-
-class GetImageRotation(object):
-    def __init__(self):
-        self.ocr = PaddleOCR(use_angle_cls=True)
-        self.ocr_angle = PaddleOCR(use_angle_cls=True)
-
-    def get_real_rotation_when_null_rect(self, rect_list):
-        w_div_h_sum = 0
-        count = 0
-        for rect in rect_list:
-            p0 = rect[0]
-            p1 = rect[1]
-            p2 = rect[2]
-            p3 = rect[3]
-            width = abs(p1[0] - p0[0])
-            height = abs(p3[1] - p0[1])
-            w_div_h = width / height
-            if abs(w_div_h - 1.0) < 0.5:
-                count += 1
-                continue
-            w_div_h_sum += w_div_h
-        length = len(rect_list) - count
-        if length == 0:
-            length = 1
-        if w_div_h_sum / length >= 1.5:
-            return 1
-        else:
-            return 0
-
-    def get_real_rotation_flag(self, rect_list):
-        ret_rect = []
-        w_div_h_mean = 0
-        real_rect_count = 0
-        rect_big_list = []
-        rect_small_list = []
-        w_div_h_sum_big = []
-        w_div_h_sum_small = []
-        for rect in rect_list:
-            p0 = rect[0]
-            p1 = rect[1]
-            p2 = rect[2]
-            p3 = rect[3]
-            width = abs(p1[0] - p0[0])
-            height = abs(p3[1] - p0[1])
-            w_div_h = width / height
-            if 5 <= w_div_h <= 25:
-                real_rect_count += 1
-                rect_big_list.append(rect)
-                w_div_h_sum_big.append(w_div_h)
-
-            if 0.04 <= w_div_h <= 0.2:
-                real_rect_count -= 1
-                rect_small_list.append(rect)
-                w_div_h_sum_small.append(w_div_h)
-        if real_rect_count > 0:
-            ret_rect = rect_big_list
-            w_div_h_mean = np.mean(w_div_h_sum_big)
-        else:
-            ret_rect = rect_small_list
-            w_div_h_mean = np.mean(w_div_h_sum_small)
-
-        if w_div_h_mean >= 1.5:
-            return 1, ret_rect
-        else:
-            return 0, ret_rect
-
-    def crop_image(self, rect, image):
-        p0 = rect[0]
-        p1 = rect[1]
-        p2 = rect[2]
-        p3 = rect[3]
-        crop = image[int(p0[1]):int(p2[1]), int(p0[0]):int(p2[0])]
-        # crop_image = Image.fromarray(crop)
-        return crop
-
-    def get_img_real_angle(self, img):
-        ret_angle = 0
-        image = img
-        # ocr = PaddleOCR(use_angle_cls=True)
-        # angle_cls = ocr.ocr(img_path, det=False, rec=False, cls=True)
-
-        rect_list = self.ocr.ocr(image, rec=False)
-        if rect_list != [[]]:
-            except_flag = False
-            try:
-                real_angle_flag, rect_good = self.get_real_rotation_flag(
-                    rect_list)
-                rect_crop = choice(rect_good)
-                # rect_crop = rect_good[0]
-                image_crop = self.crop_image(rect_crop, image)
-                # ocr_angle = PaddleOCR(use_angle_cls=True)
-                angle_cls = self.ocr_angle.ocr(
-                    image_crop, det=False, rec=False, cls=True)
-            except:
-                except_flag = True
-                real_angle_flag = self.get_real_rotation_when_null_rect(
-                    rect_list)
-                # ocr_angle = PaddleOCR(use_angle_cls=True)
-                angle_cls = self.ocr_angle.ocr(
-                    image, det=False, rec=False, cls=True)
-        else:
-            return 0
-        if angle_cls[0][0] == '0':
-            if real_angle_flag:
-                ret_angle = 0
-            else:
-                ret_angle = 270
-                if not except_flag:
-                    anticlockwise_90 = rotate_bound_white_bg(image_crop, 90)
-                    angle_cls = self.ocr_angle.ocr(anticlockwise_90, det=False, rec=False, cls=True)
-                    if angle_cls[0][0] == '0':
-                        ret_angle = 270
-                    if angle_cls[0][0] == '180':
-                        ret_angle = 90
-        if angle_cls[0][0] == '180':
-            if real_angle_flag:
-                ret_angle = 180
-            else:
-                ret_angle = 90
-        return ret_angle
-
-
-def paddle(request: Request, info: ImageListInfo):
+def paddle(request: Request,info: ImageListInfo):
     logger.info(f"->图片数量:{len(info.images)}")
     res_list = []
     for b_img in info.images:
         img = base64_to_np(b_img)
-        route = GetImageRotation()
-        route2 = route.get_img_real_angle(img)
-        if route2 == 90 or route2 == 270:
-            img = im.transpose(img.ROTATE_90)
-        result = ocr.ocr(img, cls=True)
+        result=ocr.ocr(img,cls=True)
         r_list = []
         for text_list in result:
             if len(text_list) >= 1:
@@ -234,6 +94,8 @@ def paddle(request: Request, info: ImageListInfo):
     return res_list
 
 
+
+
 if __name__ == '__main__':
     import uvicorn
     import argparse
@@ -243,5 +105,6 @@ if __name__ == '__main__':
     parser.add_argument('--port', default=8080)
     opt = parser.parse_args()
 
+
     app_str = 'server:app'  # make the app string equal to whatever the name of this file is
-    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)