Jelajahi Sumber

feat: add 829新模型

Zhang Li 1 tahun lalu
induk
melakukan
f89314d16d

+ 16 - 16
Dockerfile

@@ -60,24 +60,24 @@ RUN ln /opt/conda/bin/mamba /usr/local/bin/mamba && mamba init zsh
 
 FROM builder1 as builder2
 
-# RUN apt-get update && apt-get install -y --no-install-recommends openssh-server && rm -rf /var/lib/apt/lists/*
-# RUN mkdir /var/run/sshd
-# RUN echo 'root:root' | chpasswd
-# RUN sed -i 's/.*PermitRootLogin .*/PermitRootLogin yes/' /etc/ssh/sshd_config
-# # SSH login fix. Otherwise user is kicked off after login
-# RUN sed -i 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' /etc/pam.d/sshd
+RUN apt-get update && apt-get install -y --no-install-recommends openssh-server && rm -rf /var/lib/apt/lists/*
+RUN mkdir /var/run/sshd
+RUN echo 'root:root' | chpasswd
+RUN sed -i 's/.*PermitRootLogin .*/PermitRootLogin yes/' /etc/ssh/sshd_config
+# SSH login fix. Otherwise user is kicked off after login
+RUN sed -i 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' /etc/pam.d/sshd
 
-# RUN echo "\
-# [program:sshd] \n\
-# command=/usr/sbin/sshd -D\n\
-# autorestart=True\n\
-# autostart=True\n\
-# redirect_stderr = true\n\
-# " > /etc/supervisor/conf.d/sshd.conf
+RUN echo "\
+[program:sshd] \n\
+command=/usr/sbin/sshd -D\n\
+autorestart=True\n\
+autostart=True\n\
+redirect_stderr = true\n\
+" > /etc/supervisor/conf.d/sshd.conf
 
-# EXPOSE 22
+EXPOSE 22
 
-# FROM builder2 as builder3
+FROM builder2 as builder3
 
 WORKDIR /workspace
 ADD environment.yml /environment.yml
@@ -89,7 +89,7 @@ RUN mamba update -n base -c defaults conda -y && mamba env create -f /environmen
 RUN echo "\
 [program:be]\n\
 directory=/workspace\n\
-command=/opt/conda/envs/py38/bin/gunicorn server:app --workers 1 --worker-class=uvicorn.workers.UvicornWorker  --bind 0.0.0.0:8080 --reload \n\
+command=/opt/conda/envs/py38/bin/python run.py \n\
 autorestart=true\n\
 startretries=0\n\
 redirect_stderr=true\n\

+ 9 - 0
Makefile

@@ -15,3 +15,12 @@ gpu:
 cpu:
 	@docker build -f cpu.Dockerfile -t registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):cpu --build-arg VERSION=cpu .
 	@docker push registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):cpu
+
+image: rsync
+	@ssh sxkj@192.168.199.107 -t 'cd /home/sxkj/zhangli/ocr-table && docker build -t SXKJ:32775/$(NAME):$(VERSION) .'
+
+deploy: rsync
+	@ssh sxkj@192.168.199.107 -t 'cd /home/sxkj/zhangli/ocr-table && docker-compose down && docker-compose up -d'
+
+rsync:
+	@rsync -azP --exclude ".*/"  --exclude "tmp/" `pwd` sxkj@192.168.199.107:/home/sxkj/zhangli

+ 3 - 3
docker-compose.yml

@@ -4,7 +4,7 @@ services:
     hostname: table
     container_name: table
     restart: always
-    image: registry.cn-hangzhou.aliyuncs.com/sxtest/table:gpu
+    image: SXKJ:32775/table:latest
     privileged: true
     ipc: host
     tty: true
@@ -12,8 +12,8 @@ services:
     ports:
       - '18099:8080'
       - '18555:22'
-    # volumes:
-    #   - ./:/workspace
+    volumes:
+      - ./:/workspace
 #    deploy:
 #      resources:
 #        reservations:

+ 1 - 1
environment.yml

@@ -20,7 +20,7 @@ dependencies:
       - python-multipart
       - requests
       - cpca
-      - gunicorn
+      - uvicorn
       - protobuf==3.20.1
       - -i https://mirror.baidu.com/pypi/simple
       - paddlepaddle  # gpu==2.3.0.post110

TEMPAT SAMPAH
models/table/SLANet_829/inference.pdiparams


TEMPAT SAMPAH
models/table/SLANet_829/inference.pdiparams.info


TEMPAT SAMPAH
models/table/SLANet_829/inference.pdmodel


+ 92 - 48
server.py

@@ -9,8 +9,13 @@ from pydantic import BaseModel
 from paddleocr import PaddleOCR, PPStructure
 from sx_utils.sxweb import *
 from sx_utils.sximage import *
-
+import threading
 import os
+import re
+
+table_engine_lock = threading.Lock()
+
+
 
 # 初始化app
 app = FastAPI()
@@ -24,36 +29,78 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-use_gpu = os.getenv('USE_CUDA') == 'gpu'
-print(f'use gpu: {use_gpu}')
 
-# 普通表格
+
+
 table_engine = PPStructure(layout=False,
                            table=True,
-                           use_gpu=use_gpu,
+                           use_gpu=True,
                            show_log=True,
-                           det_model_dir="models/det/det_table_v2",
-                           rec_model_dir="./models/rec/rec_table_v1",
-                           table_model_dir="models/table/SLAnet_v1")
-
-# 长度较长表格
-table_engine1 = PPStructure(layout=False,
-                            table=True,
-                            use_gpu=use_gpu,
-                            show_log=True,
-                            det_model_dir="models/det/det_table_v1",
-                            rec_model_dir="./models/rec/rec_table_v1",
-                            table_model_dir="./models/table/SLAnet_v1")
-
-# 针对某些特殊情况的补充模型
-table_engine2 = PPStructure(layout=False,
-                            table=True,
-                            use_gpu=use_gpu,
-                            show_log=True,
-                            det_model_dir="models/det/det_table_v3",
-                            rec_model_dir="./models/rec/rec_table_v1",
-                            table_model_dir="./models/table/SLAnet_v1")
-
+                           # det_model_dir="models/det/det_table_v2",
+                           # rec_model_dir="models/rec/rec_table_v1",
+                           table_model_dir="models/table/SLANet_829")
+
+# # 普通表格
+# table_engine = PPStructure(layout=False,
+#                            table=True,
+#                            use_gpu=use_gpu,
+#                            show_log=True,
+#                            det_model_dir="models/det/det_table_v2",
+#                            rec_model_dir="./models/rec/rec_table_v1",
+#                            table_model_dir="models/table/SLANet_v2")
+#
+# # 长度较长表格
+# table_engine1 = PPStructure(layout=False,
+#                             table=True,
+#                             use_gpu=use_gpu,
+#                             show_log=True,
+#                             det_model_dir="models/det/det_table_v1",
+#                             rec_model_dir="./models/rec/rec_table_v1",
+#                             table_model_dir="./models/table/SLAnet_v1")
+#
+# # 针对某些特殊情况的补充模型
+# table_engine2 = PPStructure(layout=False,
+#                             table=True,
+#                             use_gpu=use_gpu,
+#                             show_log=True,
+#                             det_model_dir="models/det/det_table_v3",
+#                             rec_model_dir="./models/rec/rec_table_v1",
+#                             table_model_dir="./models/table/SLAnet_v1")
+#
+#
+#
+
+# 用于判断各个角度table的识别效果,识别的字段越多,效果越好
+def cal_html_to_chs(html):
+    res = []
+    rows = re.split('<tr>', html)
+    for row in rows:
+        row = re.split('<td>', row)
+        cells = list(map(lambda x: x.replace('</td>', '').replace('</tr>', ''), row))
+        rec_str = ''.join(cells)
+        for tag in ['<html>', '</html>', '<body>', '</body>', '<table>', '</table>', '<tbody>', '</tbody>']:
+            rec_str = rec_str.replace(tag, '')
+
+        res.append(rec_str)
+
+    rec_res = ''.join(res).replace(' ', '')
+    rec_res = re.split('<tdcolspan="\w+">', rec_res)
+    rec_res = ''.join(rec_res).replace(' ', '')
+    print(rec_res)
+    return len(rec_res)
+
+
+def table_res(im, ROTATE=-1):
+    im = im.copy()
+    if ROTATE >= 0:
+        im = cv2.rotate(im, ROTATE)
+    try:
+        table_engine_lock.acquire()
+        res = table_engine(im)
+    finally:
+        table_engine_lock.release()
+    html = res[0]['res']['html']
+    return res, html
 
 class TableInfo(BaseModel):
     image: str
@@ -62,30 +109,27 @@ class TableInfo(BaseModel):
 
 @app.get("/ping")
 def ping():
-    return 'pong!'
+    return 'pong!!!!!!!!!'
 
 
 @app.post("/ocr_system/table")
 @web_try()
 def table(image: TableInfo):
     img = base64_to_np(image.image)
-    if image.det == 'no':
-        res = table_engine(img)
-    elif image.det == 'yes':
-        res = table_engine1(img)
-    elif image.det == 'spe':
-        res = table_engine2(img)
-    return res[0]['res']
-
-
-if __name__ == '__main__':
-    import uvicorn
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--host', default='0.0.0.0')
-    parser.add_argument('--port', default=8080)
-    opt = parser.parse_args()
-
-    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
-    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)
+    res_len = 0
+    res = None
+    for i in [-1, 0, 1, 2]:
+        _res, html = table_res(img, i)
+        print(html)
+        _res_len = cal_html_to_chs(html)
+        if _res_len > res_len:
+            res = _res
+            res_len = _res_len
+
+    if res:
+        return res[0]['res']
+    else:
+        raise Exception('无法识别')
+
+
+print('table system init success!')

+ 6 - 0
sx_utils/sximage.py

@@ -7,3 +7,9 @@ def base64_to_np(img_data):
     img_data = img_data.split(',',1)[-1]
     return cv2.imdecode(np.fromstring(b64decode(img_data), dtype=np.uint8), color_image_flag)
 
+def base64_cv2(base64_str):
+    base64_str = base64_str.split(',', 1)[-1]
+    img_str = base64.b64decode(base64_str)
+    np_arr = np.fromstring(img_str, np.int8)
+    image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
+    return image