123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
- import io
- import json
- import re
- from fastapi import FastAPI, Request, File, UploadFile, Body
- from fastapi.middleware.cors import CORSMiddleware
- from sx_utils.sximage import *
- from sx_utils.sxtime import sxtimeit
- from sx_utils.sxweb import web_try
- import requests
- from PIL import Image
- from pydantic import BaseModel
- import sys
- import logging
- import os
- import cv2
- from paddleocr import PaddleOCR
# Shared service logger: DEBUG level, one console handler writing to stdout.
logger = logging.getLogger('log')
logger.setLevel(logging.DEBUG)
# If this module gets loaded more than once, every load would attach another
# handler and each message would be printed several times. Drop whatever is
# already attached to this logger before adding the handler again.
# Iterate over a copy because removeHandler() mutates logger.handlers, and do
# not loop on logger.hasHandlers(): it also reports handlers attached to
# ancestor loggers (e.g. the root logger), which this loop never removes, so
# such a loop would never terminate.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
# Format shared by every handler of this logger.
formatter = logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')
# File logging; keep this part commented out when running locally.
# fh = logging.FileHandler(r'/var/log/be.log', encoding='utf-8')  # log file path and encoding
# fh.setLevel(logging.DEBUG)  # log level
# fh.setFormatter(formatter)
# logger.addHandler(fh)
# Console logging.
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.DEBUG)
ch.setFormatter(formatter)
logger.addHandler(ch)
# HTTP application object served by uvicorn.
app = FastAPI()

# CORS is configured with wildcards for origins, methods and headers.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)

# GPU mode is requested by exporting USE_CUDA=gpu; any other value means False.
use_gpu = os.getenv('USE_CUDA') == 'gpu'
logger.info(f"->是否使用GPU:{use_gpu}")

# Module-wide OCR engine loaded from the local model directories.
# NOTE(review): use_gpu is only logged above; it is not passed to PaddleOCR
# here - confirm whether the engine is meant to receive it.
ocr = PaddleOCR(
    use_angle_cls=True,
    rec_model_dir="./table_rec_infer/",
    det_model_dir="./table_det_infer/",
    cls_model_dir="table_cls_infer",
    lang="ch",
)
@app.get("/ping")
def ping():
    """Answer GET /ping with the fixed string ``pong!``."""
    reply = "pong!"
    return reply
class ImageListInfo(BaseModel):
    """Request body accepted by the ``paddle`` OCR handler."""

    # Images to recognise; the handler passes each item to base64_to_np.
    images: list
    # Copied unchanged into the "type" field of every returned text entry.
    img_type: str
def rotate_bound_white_bg(image, angle):
    """Rotate ``image`` by ``angle`` degrees without cropping its corners.

    The output canvas is enlarged to the bounding box of the rotated image and
    the uncovered area is filled with white.

    Args:
        image: array whose first two dimensions are (height, width).
        angle: rotation in degrees. It is negated before being handed to
            ``cv2.getRotationMatrix2D``, whose positive angles are
            counter-clockwise, so a positive ``angle`` turns the image
            clockwise.

    Returns:
        The rotated image produced by ``cv2.warpAffine``.
    """
    (h, w) = image.shape[:2]
    (cX, cY) = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D((cX, cY), -angle, 1.0)
    cos = abs(M[0, 0])
    sin = abs(M[0, 1])
    # Size of the bounding box that contains the whole rotated image.
    nW = int((h * sin) + (w * cos))
    nH = int((h * cos) + (w * sin))
    # Shift the rotation centre to the centre of the enlarged canvas.
    M[0, 2] += (nW / 2) - cX
    M[1, 2] += (nH / 2) - cY
    return cv2.warpAffine(image, M, (nW, nH), borderValue=(255, 255, 255))


# The route decorators must wrap a function taking (request, info). They were
# stacked on the image helper above, which registered the helper as the HTTP
# handler and left ``paddle`` unreachable. ``paddle`` is defined further down
# in this module, so it is looked up when a request arrives, not at import.
@app.post("/ocr_system/paddle")
@sxtimeit
@web_try()
def paddle_endpoint(request: Request, info: ImageListInfo):
    """Handle POST /ocr_system/paddle by delegating to ``paddle``."""
    return paddle(request, info)
class GetImageRotation(object):
    """Estimate the rotation of an image as one of 0, 90, 180 or 270.

    Text boxes are detected with PaddleOCR. Their width/height ratios give a
    flag telling whether wide boxes dominate, and the PaddleOCR angle
    classifier ('0' / '180') is combined with that flag to pick the angle.
    """

    def __init__(self):
        # Engine used for text detection.
        self.ocr = PaddleOCR(use_angle_cls=True)
        # Engine used for angle classification only.
        self.ocr_angle = PaddleOCR(use_angle_cls=True)

    @staticmethod
    def _aspect_ratio(rect):
        """Return width / height of a four-point box, or None if height is 0.

        Width is |x1 - x0| and height is |y3 - y0|, where point i is rect[i].
        """
        width = abs(rect[1][0] - rect[0][0])
        height = abs(rect[3][1] - rect[0][1])
        if height == 0:
            return None
        return width / height

    def get_real_rotation_when_null_rect(self, rect_list):
        """Fallback flag computed from the mean aspect ratio of all boxes.

        Boxes whose ratio is within 0.5 of 1.0 (roughly square) and boxes with
        zero height are ignored.

        Args:
            rect_list: iterable of four-point boxes ``[p0, p1, p2, p3]``.

        Returns:
            1 when the mean ratio of the remaining boxes is at least 1.5,
            otherwise 0 (also when no box remains).
        """
        ratios = []
        for rect in rect_list:
            ratio = self._aspect_ratio(rect)
            if ratio is None or abs(ratio - 1.0) < 0.5:
                continue
            ratios.append(ratio)
        if not ratios:
            return 0
        return 1 if sum(ratios) / len(ratios) >= 1.5 else 0

    def get_real_rotation_flag(self, rect_list):
        """Classify boxes as wide or tall and report which kind dominates.

        A box is "wide" when 5 <= width/height <= 25 and "tall" when
        0.04 <= width/height <= 0.2; other boxes, and boxes with zero height,
        are ignored.

        Args:
            rect_list: iterable of four-point boxes ``[p0, p1, p2, p3]``.

        Returns:
            ``(1, wide_boxes)`` when there are more wide than tall boxes,
            otherwise ``(0, tall_boxes)``; the list may be empty.
        """
        wide_rects = []
        tall_rects = []
        for rect in rect_list:
            ratio = self._aspect_ratio(rect)
            if ratio is None:
                continue
            if 5 <= ratio <= 25:
                wide_rects.append(rect)
            elif 0.04 <= ratio <= 0.2:
                tall_rects.append(rect)
        # Every wide ratio is >= 5 and every tall ratio is <= 0.2, so the
        # dominant group alone decides the flag; no mean is needed.
        if len(wide_rects) > len(tall_rects):
            return 1, wide_rects
        return 0, tall_rects

    def crop_image(self, rect, image):
        """Return the part of ``image`` spanned by points 0 and 2 of ``rect``.

        Point 0 gives the start (x, y) and point 2 the end (x, y) of the slice;
        coordinates are truncated to int.
        """
        start = rect[0]
        end = rect[2]
        return image[int(start[1]):int(end[1]), int(start[0]):int(end[0])]

    def get_img_real_angle(self, img):
        """Estimate the rotation of ``img``.

        Args:
            img: image passed to ``PaddleOCR.ocr`` and sliced by
                ``crop_image``.

        Returns:
            One of 0, 90, 180, 270. 0 is also returned when no text box is
            detected or the classifier label is neither '0' nor '180'.
            NOTE(review): the direction convention of the returned angle is
            not visible in this file - confirm before rotating with it.
        """
        import random

        image = img
        # assumes the detector returns a list of four-point boxes - TODO
        # confirm against the installed PaddleOCR version.
        rect_list = self.ocr.ocr(image, rec=False)
        if not rect_list or rect_list == [[]]:
            return 0

        # Preferred path: classify one randomly chosen well-shaped text box.
        image_crop = None
        try:
            real_angle_flag, rect_good = self.get_real_rotation_flag(rect_list)
            if rect_good:
                candidate = self.crop_image(random.choice(rect_good), image)
                angle_cls = self.ocr_angle.ocr(
                    candidate, det=False, rec=False, cls=True)
                image_crop = candidate
        except Exception:
            # Deliberate best effort: fall back to the whole image below, but
            # leave a trace instead of hiding the error.
            logging.getLogger('log').warning(
                "angle detection on a cropped text box failed; "
                "falling back to the whole image", exc_info=True)
            image_crop = None

        if image_crop is None:
            # Fallback path: no usable box, classify the whole image.
            real_angle_flag = self.get_real_rotation_when_null_rect(rect_list)
            angle_cls = self.ocr_angle.ocr(
                image, det=False, rec=False, cls=True)

        label = angle_cls[0][0]
        if label == '0':
            if real_angle_flag:
                return 0
            ret_angle = 270
            if image_crop is not None:
                # Turn the crop a quarter turn clockwise and classify again to
                # tell 270 from 90.
                turned = cv2.rotate(image_crop, cv2.ROTATE_90_CLOCKWISE)
                turned_label = self.ocr_angle.ocr(
                    turned, det=False, rec=False, cls=True)[0][0]
                if turned_label == '180':
                    ret_angle = 90
            return ret_angle
        if label == '180':
            return 180 if real_angle_flag else 90
        return 0
# Rotation detector shared by all requests; created on first use because
# building it loads two PaddleOCR engines.
_rotation_detector = None


def _get_rotation_detector():
    """Return the shared GetImageRotation instance, creating it if needed."""
    global _rotation_detector
    if _rotation_detector is None:
        _rotation_detector = GetImageRotation()
    return _rotation_detector


def paddle(request: Request, info: ImageListInfo):
    """Run OCR on every image of ``info`` and return the recognised text.

    Args:
        request: incoming request; not used by this function.
        info: request body with the images and the ``img_type`` label.

    Returns:
        One list per input image. Each entry is a dict with the keys
        ``confidence``, ``text``, ``type`` (copied from ``info.img_type``) and
        ``text_region``.
    """
    logger.info(f"->图片数量:{len(info.images)}")
    res_list = []
    for b_img in info.images:
        img = base64_to_np(b_img)
        angle = _get_rotation_detector().get_img_real_angle(img)
        if angle in (90, 270):
            # Give the image one counter-clockwise quarter turn, the same for
            # both angles. NOTE(review): presumably cls=True below resolves
            # the remaining upside-down case - confirm.
            img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        result = ocr.ocr(img, cls=True)
        r_list = []
        # assumes each entry is [text_region, (text, confidence)] - TODO
        # confirm against the installed PaddleOCR version.
        for text_list in result or []:
            if not text_list:
                continue
            r_list.append({
                "confidence": text_list[1][1],
                "text": text_list[1][0],
                "type": info.img_type,
                "text_region": text_list[0],
            })
        res_list.append(r_list)
    return res_list
if __name__ == '__main__':
    # Script entry point: parse --host/--port and start uvicorn with reload.
    import argparse

    import uvicorn

    cli = argparse.ArgumentParser()
    cli.add_argument('--host', default='0.0.0.0')
    cli.add_argument('--port', default=8080)
    args = cli.parse_args()
    # Import string "<module>:<attribute>"; the module part has to equal the
    # name of this file.
    uvicorn.run('server:app', host=args.host, port=int(args.port), reload=True)
|