123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192 |
- # -*- coding: UTF-8 -*-
- import json
- from base64 import b64decode
- import base64
- import cv2
- import numpy as np
- from fastapi import FastAPI, Request
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
- from paddleocr import PaddleOCR, PPStructure
- from sx_utils.sxweb import *
- from sx_utils.sximage import *
- import threading
- import os
- import re
- from sx_utils.sx_log import *
- import paddleclas
- from cores.post_hander import *
- from cores.check_table import *
# NOTE(review): format_print is star-imported (presumably from sx_utils);
# its effect is not visible in this file -- confirm.
format_print()

# Initialise the FastAPI application.
app = FastAPI()

# CORS: any origin, method and header is accepted, credentials included.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Table recognition engine plus the lock used to serialise calls to it.
table_engine_lock = threading.Lock()
table_engine = PPStructure(
    layout=False,
    table=True,
    use_gpu=True,
    show_log=True,
    use_angle_cls=True,
    # det_model_dir="models/det/det_table_v2",
    # det_model_dir="models/det/det_table_v3",
    # rec_model_dir="models/rec/rec_table_v1",
    table_model_dir="models/table/SLANet_911",
)

# Text-image orientation classifier plus the lock used around it.
cls_lock = threading.Lock()
cls_model = paddleclas.PaddleClas(model_name="text_image_orientation")
- # # Regular tables
- # table_engine = PPStructure(layout=False,
- # table=True,
- # use_gpu=use_gpu,
- # show_log=True,
- # det_model_dir="models/det/det_table_v2",
- # rec_model_dir="./models/rec/rec_table_v1",
- # table_model_dir="models/table/SLANet_v2")
- #
- # # Longer tables
- # table_engine1 = PPStructure(layout=False,
- # table=True,
- # use_gpu=use_gpu,
- # show_log=True,
- # det_model_dir="models/det/det_table_v1",
- # rec_model_dir="./models/rec/rec_table_v1",
- # table_model_dir="./models/table/SLAnet_v1")
- #
- # # Supplementary model for certain special cases
- # table_engine2 = PPStructure(layout=False,
- # table=True,
- # use_gpu=use_gpu,
- # show_log=True,
- # det_model_dir="models/det/det_table_v3",
- # rec_model_dir="./models/rec/rec_table_v1",
- # table_model_dir="./models/table/SLAnet_v1")
- #
- #
- #
# Used to judge how well a table was recognised at a given rotation angle:
# the more text that was recognised, the better the result.
def cal_html_to_chs(html):
    """Count the recognised characters in a predicted table HTML string.

    ``<tr>``/``<td>`` tags and their closing tags, the html/body/table/tbody
    wrapper tags, ``<td colspan="...">`` openers and all spaces are removed.
    The remaining text is printed and its length is returned.

    Args:
        html: HTML string of the recognised table.

    Returns:
        Number of characters left once markup and spaces are stripped.
    """
    wrapper_tags = (
        '<html>', '</html>', '<body>', '</body>',
        '<table>', '</table>', '<tbody>', '</tbody>',
    )
    row_texts = []
    for row in re.split('<tr>', html):
        cells = [
            cell.replace('</td>', '').replace('</tr>', '')
            for cell in re.split('<td>', row)
        ]
        row_text = ''.join(cells)
        for tag in wrapper_tags:
            row_text = row_text.replace(tag, '')
        row_texts.append(row_text)

    text = ''.join(row_texts).replace(' ', '')
    # Spaces are already gone, so '<td colspan="2">' now reads
    # '<tdcolspan="2">'. The pattern is a raw string so the regex escape is
    # not parsed as an (invalid) string escape sequence.
    text = re.sub(r'<tdcolspan="\w+">', '', text)
    print(text)
    return len(text)
def predict_cls(image, conf=0.8):
    """Classify the text orientation of an image.

    Args:
        image: Image handed unchanged to ``cls_model.predict``.
        conf: Score a prediction must exceed (strictly) to be accepted.

    Returns:
        The top-1 label converted to ``int`` for the first result whose
        top-1 score is above ``conf`` (the caller treats it as 0, 90, 180 or
        270), or -1 when no result qualifies.
    """
    # `with` takes the lock before entering the guarded region, so a failed
    # acquire can never lead to releasing a lock that is not held.
    # The results are also consumed under the lock.
    # NOTE(review): PaddleClas documents predict() as a generator, in which
    # case inference only runs while iterating -- confirm.
    with cls_lock:
        for res in cls_model.predict(image):
            score = res[0]['scores'][0]
            label_name = res[0]['label_names'][0]
            print(f"score: {score}, label_name: {label_name}")
            if score > conf:
                return int(label_name)
    return -1
def rotate_to_zero(image, current_degree):
    """Rotate ``image`` back to the 0-degree orientation.

    Args:
        image: Image to rotate.
        current_degree: Current orientation in degrees; it is floor-divided
            by 90 to get a number of quarter turns.

    Returns:
        ``image`` itself when the quarter-turn count is 0, otherwise the
        result of ``cv2.rotate`` with rotate code ``3 - quarter_turns``
        (90 -> code 2, 180 -> code 1, 270 -> code 0).
    """
    quarter_turns = current_degree // 90
    if quarter_turns == 0:
        # Already upright: hand back the very same object.
        return image
    return cv2.rotate(image, 3 - quarter_turns)
def get_zero_degree_image(img):
    """Return an upright version of ``img`` when its orientation can be found.

    Four candidates are tried in order: a plain copy of ``img``, then a copy
    rotated with cv2 rotate codes 0, 1 and 2. Each rotation is applied to the
    original input, not to the previous candidate. The confidence threshold
    starts at 0.8 and drops by 0.2 per attempt. The first candidate for which
    ``predict_cls`` returns something other than -1 is passed through
    ``rotate_to_zero`` and returned; if none qualifies, ``img`` is returned
    unchanged.
    """
    threshold_step = 0.2
    # None means "use the image as-is"; 0/1/2 are cv2.rotate codes.
    for attempt, rotate_code in enumerate((None, 0, 1, 2)):
        candidate = img.copy()
        if rotate_code is not None:
            candidate = cv2.rotate(candidate, rotate_code)
        threshold = 0.8 - (attempt * threshold_step)
        # 0 / 90 / 180 / 270, or -1 when the orientation is not recognised.
        detected_degree = predict_cls(candidate, threshold)
        if detected_degree != -1:
            return rotate_to_zero(candidate, detected_degree)
    return img
def table_res(im, ROTATE=-1):
    """Run table recognition on an orientation-corrected copy of ``im``.

    Args:
        im: Input image; must provide ``.copy()`` (assumed to be a numpy
            array -- TODO confirm). The work is done on a copy.
        ROTATE: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A ``(res, html)`` tuple: the raw output of ``table_engine`` and the
        ``['res']['html']`` value of its first entry.

    Raises:
        IndexError, KeyError: If the engine output has no first entry or it
            lacks the ``res``/``html`` keys.
    """
    # Get the image in its upright orientation.
    upright = get_zero_degree_image(im.copy())
    # `with` acquires before the guarded region starts, so a failed acquire
    # cannot end in releasing a lock that was never taken.
    with table_engine_lock:
        res = table_engine(upright)
    html = res[0]['res']['html']
    return res, html
class TableInfo(BaseModel):
    """Request body accepted by the ``/ocr_system/table`` endpoint."""

    # Image payload; the endpoint decodes it with base64_to_np.
    image: str
    # Required by the schema, but not read by the code visible in this file.
    det: str
@app.get("/ping")
def ping():
    """Health-check endpoint: always answers with the same fixed string."""
    return 'pong!!!!!!!!!'
@app.post("/ocr_system/table")
@web_try()
def table(image: TableInfo):
    """Recognise the table in the posted image and return it as HTML.

    The image is decoded, run through ``table_res``, and the resulting HTML
    is wrapped in a ``Table`` together with the image. When
    ``Table.check_html()`` is truthy, recognition is run a second time on
    ``Table.img`` and that HTML is used instead.
    NOTE(review): what check_html verifies, and how it changes ``img``, is
    defined in cores.check_table and not visible here.

    Returns:
        ``{'html': ...}`` holding ``PostHandler(html).format_predict_html``.

    Raises:
        Exception: When the recognised HTML is empty.
    """
    source_img = base64_to_np(image.image)
    _, html = table_res(source_img)

    checker = Table(html, source_img)
    if checker.check_html():
        _, html = table_res(checker.img)

    if not html:
        raise Exception('无法识别')
    return {'html': PostHandler(html).format_predict_html}
- print('table system init success!')
|