Browse Source

first commit

liweiquan 4 months ago
commit
c880124636

BIN
.DS_Store


+ 99 - 0
Dockerfile

@@ -0,0 +1,99 @@
+FROM nvidia/cuda:11.0-cudnn8-devel-ubuntu18.04 AS builder
+
+# Locale defaults for this and every derived stage. DEBIAN_FRONTEND is kept
+# in ENV (not a stage-local ARG) because only ENV is inherited by the stages
+# that build `FROM builder`, and those stages may run apt-get as well.
+ENV LANG=zh_CN.UTF-8 LANGUAGE=zh_CN:zh LC_ALL=zh_CN.UTF-8 DEBIAN_FRONTEND=noninteractive
+
+# Point apt at the Aliyun mirror and drop anything under sources.list.d so
+# that `apt-get update` only talks to the mirror.
+RUN sed -i 's#archive.ubuntu.com#mirrors.aliyun.com#g' /etc/apt/sources.list \
+    && sed -i 's#security.ubuntu.com#mirrors.aliyun.com#g' /etc/apt/sources.list \
+    && rm -rf /etc/apt/sources.list.d/
+
+# `apt-get update` and `apt-get install` share one layer so a cached, stale
+# package index can never be combined with a changed package list. The index
+# is removed in the same layer to keep it out of the image.
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    autojump \
+    build-essential \
+    ca-certificates \
+    cmake \
+    curl \
+    git \
+    iputils-ping \
+    language-pack-zh-hans \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    libjpeg-dev \
+    libpng-dev \
+    locales \
+    openssh-server \
+    supervisor \
+    ttf-wqy-zenhei \
+    unzip \
+    vim \
+    wget \
+    zip \
+    zsh \
+    && rm -rf /var/lib/apt/lists/*
+
+# Generate the zh_CN.UTF-8 locale referenced by the ENV line above.
+RUN locale-gen zh_CN.UTF-8 && dpkg-reconfigure locales
+
+# supervisord stays in the foreground (-n) and runs the programs that later
+# stages register under /etc/supervisor/conf.d/.
+CMD ["supervisord", "-n"]
+
+FROM builder AS builder1
+
+# Major version used to pick the Miniconda installer (Miniconda3-latest-...).
+ENV PYTHON_VERSION=3
+
+# Make zsh the login shell of the current (build-time) user.
+RUN chsh -s "$(command -v zsh)"
+
+# Download, run and delete the installer in one layer so it never persists in
+# the image. `-f` makes curl fail on an HTTP error instead of saving the error
+# page as the installer; a single `-o` replaces the conflicting `-o ... -O`.
+RUN curl -fsSL -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda${PYTHON_VERSION}-latest-Linux-x86_64.sh \
+    && chmod +x ~/miniconda.sh \
+    && ~/miniconda.sh -b -p /opt/conda \
+    && rm ~/miniconda.sh
+
+# Symlinks (not hard links) so /usr/local/bin keeps following the real
+# entry points after conda/mamba are updated in place.
+RUN ln -s /opt/conda/bin/conda /usr/local/bin/conda \
+    && conda init zsh
+
+# -y keeps the install non-interactive during the build.
+RUN conda install -y -n base -c conda-forge mamba \
+    && ln -s /opt/conda/bin/mamba /usr/local/bin/mamba \
+    && mamba init zsh
+
+
+
+FROM builder1 AS builder2
+
+# openssh-server is already installed by the `builder` stage, so no second
+# apt-get pass is needed here.
+# Create the runtime directory used by sshd (-p: no failure if it exists).
+RUN mkdir -p /var/run/sshd
+
+# SECURITY(review): this bakes the well-known password `root` for the root
+# account into the image and allows root login over SSH. Only acceptable on an
+# isolated development network; change the password (or switch to key-based
+# authentication) before exposing port 22 anywhere else.
+RUN echo 'root:root' | chpasswd
+RUN sed -i 's/.*PermitRootLogin .*/PermitRootLogin yes/' /etc/ssh/sshd_config
+# SSH login fix. Otherwise user is kicked off after login
+RUN sed -i 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' /etc/pam.d/sshd
+
+# Register sshd as a supervisord program. printf writes one line per argument
+# and does not depend on the shell's `echo` interpreting "\n" escapes.
+RUN printf '%s\n' \
+    '[program:sshd]' \
+    'command=/usr/sbin/sshd -D' \
+    'autorestart=True' \
+    'autostart=True' \
+    'redirect_stderr = true' \
+    > /etc/supervisor/conf.d/sshd.conf
+
+EXPOSE 22
+
+FROM builder2 AS builder3
+
+WORKDIR /workspace
+
+# Copy only the dependency manifest first so the environment layers below are
+# reused from cache until environment.yml itself changes.
+COPY environment.yml /environment.yml
+RUN mamba update -n base -c defaults conda -y
+RUN mamba env create -f /environment.yml && rm -rf /root/.cache
+
+# Register the API server as a supervisord program. It runs gunicorn from the
+# py38 environment with a single uvicorn worker on port 8080.
+# NOTE(review): `--reload` restarts the worker when files change; presumably
+# meant for the bind-mounted development setup - confirm for production use.
+RUN printf '%s\n' \
+    '[program:be]' \
+    'directory=/workspace' \
+    'command=/opt/conda/envs/py38/bin/gunicorn server:app --workers 1 --worker-class=uvicorn.workers.UvicornWorker  --bind 0.0.0.0:8080 --reload' \
+    'autorestart=true' \
+    'startretries=0' \
+    'redirect_stderr=true' \
+    'stdout_logfile=/var/log/be.log' \
+    'stdout_logfile_maxbytes=50MB' \
+    'environment=PYTHONUNBUFFERED=1' \
+    > /etc/supervisor/conf.d/be.conf
+
+# VERSION comes from `--build-arg VERSION=...` and is exported as USE_CUDA;
+# server.py enables GPU inference only when the value is exactly `gpu`.
+ARG VERSION
+ENV USE_CUDA=$VERSION
+
+# Application source goes last so code changes do not invalidate the
+# dependency layers above. COPY replaces ADD: nothing here needs ADD's
+# archive-extraction or URL features.
+COPY . /workspace
+EXPOSE 8080
+

+ 17 - 0
Makefile

@@ -0,0 +1,17 @@
+# Image name; VERSION, BUILD_TIME, COMMIT_SHA1 and AUTHOR are defined here but
+# not referenced by any rule below.
+NAME=ppocr
+VERSION=latest
+BUILD_TIME      := $(shell date "+%F %T")
+COMMIT_SHA1     := $(shell git rev-parse HEAD)
+AUTHOR          := $(shell git show -s --format='%an')
+
+# Full repository path shared by the gpu and cpu image tags.
+IMAGE           := registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME)
+
+.PHONY: all cpu gpu
+
+# Build and push both image variants.
+all: gpu cpu
+
+# GPU image, built from the default Dockerfile.
+gpu:
+	@docker build -t $(IMAGE):gpu --build-arg VERSION=gpu .
+	@docker push $(IMAGE):gpu
+
+# CPU image, built from cpu.Dockerfile.
+cpu:
+	@docker build -f cpu.Dockerfile -t $(IMAGE):cpu --build-arg VERSION=cpu .
+	@docker push $(IMAGE):cpu

+ 33 - 0
README.md

@@ -0,0 +1,33 @@
+# 通用OCR
+
+基于paddlepaddle 图像文本识别通用OCR
+
+## 环境
+
+- python >= 3.8
+- paddleocr
+- paddlepaddle
+- cv2
+
+```shell
+conda env create -f environment.yml
+```
+
+## 服务端
+
+```shell
+# port 8080
+python server.py --port 8080
+```
+
+## 单元测试
+
+```shell
+ python -m unittest discover testing '*_test.py' -v
+```
+
+## 镜像打包
+
+```shell
+make all
+```

BIN
core/.DS_Store


+ 22 - 0
core/ocr.py

@@ -0,0 +1,22 @@
+from dataclasses import dataclass
+from typing import Any, List
+import numpy as np
+from paddleocr import PaddleOCR
+
+@dataclass
+class OcrRes:
+    confidence: float
+    text: str
+    txt_region: List[List[int]]
+
+
+@dataclass
+class Ocr:
+    ocr: PaddleOCR
+
+    def predict(self, image: List[np.ndarray]):
+        if ocr_result := self.ocr.ocr(image):
+            res = [OcrRes(ocrr[0], ocrr[1][0], ocrr[1][1]) for ocrr in ocr_result]
+        else:
+            raise Exception('图片中未识别出文字')
+        return [res]

+ 21 - 0
docker-compose.yml

@@ -0,0 +1,21 @@
+# Compose definition that runs the GPU image of the OCR service.
+version: '2'
+services:
+  ppocr:
+    hostname: ppocr
+    container_name: ppocr
+    restart: always
+    image: registry.cn-hangzhou.aliyuncs.com/sxtest/ppocr:gpu
+    # NOTE(review): privileged mode plus the host IPC namespace give the
+    # container broad access to the host - confirm both are really required.
+    privileged: true
+    ipc: host
+    tty: true
+    working_dir: /workspace
+    ports:
+      # host 8077 -> container 8080 (HTTP API); host 2277 -> container 22 (sshd)
+      - '8077:8080'
+      - '2277:22'
+    volumes:
+      # Bind-mounts the project directory over /workspace, hiding the copy
+      # that the image build placed there.
+      - ./:/workspace
+#    deploy:
+#      resources:
+#        reservations:
+#          devices:
+#            - capabilities: [gpu]

+ 29 - 0
environment.yml

@@ -0,0 +1,29 @@
+# Conda environment used by the service (created by `conda/mamba env create`).
+name: py38
+channels:
+  - http://mirrors.ustc.edu.cn/anaconda/pkgs/free/ # Anaconda mirror hosted by USTC
+  - defaults
+  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
+dependencies:
+  - python=3.8
+  - ipython
+  - pip
+  - pip:
+      - cmake
+      - opencv-python==4.6.0.66
+      - cython
+      - paddleocr==2.5.0.3
+      - paddlehub==2.2.0
+      - fastapi==0.79.0
+      - uvicorn
+      - zhon
+      - pytest
+      - jinja2
+      - aiofiles
+      - python-multipart
+      - requests
+      - cpca
+      - gunicorn
+      # pip option: use the Baidu PyPI mirror as the package index
+      - -i https://mirror.baidu.com/pypi/simple
+      - paddlepaddle-gpu==2.3.0.post110
+      # pip option: additional find-links page on paddlepaddle.org.cn
+      - -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
+prefix: /opt/conda/envs/py38

BIN
models/cls/inference.pdiparams


BIN
models/cls/inference.pdiparams.info


BIN
models/cls/inference.pdmodel


BIN
models/det/inference.pdiparams


BIN
models/det/inference.pdiparams.info


BIN
models/det/inference.pdmodel


BIN
models/rec/inference.pdiparams


BIN
models/rec/inference.pdiparams.info


BIN
models/rec/inference.pdmodel


+ 11 - 0
run.py

@@ -0,0 +1,11 @@
+if __name__ == '__main__':
+    import uvicorn
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', default='0.0.0.0')
+    parser.add_argument('--port', default=8080)
+    opt = parser.parse_args()
+
+    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)

+ 64 - 0
server.py

@@ -0,0 +1,64 @@
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from paddleocr import PaddleOCR
+from core.ocr import *
+from utils.image import *
+from utils.time import timeit
+from utils.web import web_try
+import os
+
+# 导入一些包
+
+
+# FastAPI application object; the launchers refer to it as "server:app".
+app = FastAPI()
+
+origins = ["*"]
+
+# CORS (cross-origin resource sharing): accept any origin, method and header.
+# NOTE(review): wildcard origins combined with allow_credentials=True is very
+# permissive - confirm this is intended outside development.
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+# GPU inference is enabled only when USE_CUDA is exactly "gpu".
+use_gpu = os.getenv('USE_CUDA') == 'gpu'
+print(f'use gpu: {use_gpu}')
+
+# Single OCR engine built at import time; the model directories are relative
+# paths, so they resolve against the process working directory.
+ocr = PaddleOCR(use_angle_cls=True,
+                rec_model_dir='models/rec/',
+                det_model_dir='models/det/',
+                cls_model_dir='models/cls/',
+                rec_char_type='ch',
+                use_gpu=use_gpu,
+                warmup=True)
+# Wrapper whose predict() is called by the /ppocr endpoint.
+p = Ocr(ocr)
+# Request body of POST /ppocr.
+class PPOcrInfo(BaseModel):
+    # Base64 text of the image; base64_to_np keeps only the part after the
+    # first comma, so a "data:...;base64," style prefix is tolerated.
+    image: str
+
+
+# GET health check: always answers with the fixed string "pong!".
+@app.get("/ping")
+def ping():
+    return "pong!"
+
+
+# POST /ppocr: decode the base64 image from the request body and run OCR.
+# Decorators apply bottom-up: web_try wraps the outcome in the
+# status/result/msg envelope, timeit prints the elapsed time, and app.post
+# registers the route.
+@app.post("/ppocr")
+@timeit
+@web_try()
+def cet(request: Request, ppocr: PPOcrInfo):
+    # `request` is accepted but not used in the body.
+    return p.predict(base64_to_np(ppocr.image))
+
+
+if __name__ == '__main__':
+    import uvicorn
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', default='0.0.0.0')
+    parser.add_argument('--port', default=8080)
+    opt = parser.parse_args()
+
+    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)

BIN
utils/.DS_Store


+ 0 - 0
utils/__init__.py


+ 12 - 0
utils/image.py

@@ -0,0 +1,12 @@
+from base64 import b64decode
+import numpy as np
+import cv2
+
+
+def base64_to_np(img_data):
+    color_image_flag = 1
+    img_data = img_data.split(',', 1)[-1]
+    # b64decode -> base64图片解码
+    # np.fromstring -> 从字符串中的文本数据初始化的新一维数组
+    # numpy.fromstring(string, dtype=float, count=-1, sep='')
+    return cv2.imdecode(np.fromstring(b64decode(img_data), dtype=np.uint8), color_image_flag)

+ 125 - 0
utils/time.py

@@ -0,0 +1,125 @@
+# coding=utf-8
+# Powered by SoaringNova Technology Company
+import errno
+import os
+import signal
+import time
+from collections import defaultdict
+from functools import wraps
+
+timer_counts = defaultdict(int)
+
+
+class TimeoutError(Exception):
+    """Raised by the alarm handler installed by the `timeout` decorator.
+
+    Within this module the name shadows the built-in ``TimeoutError``; this
+    class derives directly from ``Exception``.
+    """
+    pass
+
+
+def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
+    def decorator(func):
+        def _handle_timeout(signum, frame):
+            raise TimeoutError(error_message)
+
+        def wrapper(*args, **kwargs):
+            signal.signal(signal.SIGALRM, _handle_timeout)
+            signal.alarm(seconds)
+            try:
+                result = func(*args, **kwargs)
+            finally:
+                signal.alarm(0)
+            return result
+
+        return wraps(func)(wrapper)
+
+    return decorator
+
+
+class TIMELIMIT:
+    def __init__(self, limit_time=0):
+        self.st = None
+        self.et = None
+        self.limit_time = limit_time
+
+    def __enter__(self):
+        self.st = time.time()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.et = time.time()
+        dt = self.limit_time - (self.et - self.st) * 1000
+        if dt > 0: time.sleep(float(dt) / 1000)
+
+
+class TIMER:
+    total_time = {}  # type: dict
+
+    def __init__(self, tag='', enable_total=False, threshold_ms=0):
+        self.st = None
+        self.et = None
+        self.tag = tag
+        # self.tag = tag if not hasattr(g,'request_id') else '{} {}'.format(getattr(g,'request_id'),tag)
+
+        self.thr = threshold_ms
+        self.enable_total = enable_total
+        if self.enable_total:
+            if self.tag not in self.total_time.keys():
+                self.total_time[self.tag] = []
+
+    def __enter__(self):
+        self.st = time.time()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.et = time.time()
+        dt = (self.et - self.st) * 1000
+        if self.enable_total:
+            self.total_time[self.tag].append(dt)
+
+        if dt > self.thr:
+            print("{}: {}s".format(self.tag, round(dt / 1000, 4)))
+
+    @staticmethod
+    def output():
+        for k, v in TIMER.total_time.items():
+            print('{} : {}s, avg{}s'.format(k, round(sum(v) / 1000, 2), round(sum(v) / len(v) / 1000, 2)))
+
+
+def timeit(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        st = time.time()
+        ret = func(*args, **kwargs)
+        dt = time.time() - st
+        endpoint = '{}.{}'.format(func.__module__, func.__name__)
+        timer_counts[endpoint] += 1
+        print('{}[{}] finished, exec {}s'.format(endpoint, '%05d' % timer_counts[endpoint], round(dt, 4)))
+        return ret
+
+    return wrapper  # 返回
+
+
+# def timeit(func):
+#     @wraps(func)
+#     def wrapper(*args, **kwargs):
+#         endpoint = '{}.{}'.format(func.__module__, func.__name__)
+#         setattr(g,'request_id','{}[{}]'.format(endpoint,'%05d' % timer_counts[endpoint]))
+#         timer_counts[endpoint] += 1
+#         st = time.time()
+#         ret = func(*args, **kwargs)
+#         dt = time.time() - st
+#         print ('{} finished, exec {}s'.format(getattr(g,'request_id'), round(dt, 4)))
+#         return ret
+#
+#     return wrapper  # 返回
+
+def t2date(t):
+    import datetime
+    date = datetime.datetime.fromtimestamp(t)
+    return '{}_{}_{}_{}:{}:{}'.format(date.year, date.month, date.day, date.hour, date.minute, date.second)
+
+
+def day_begin(t):
+    dsecs = 24 * 3600
+    return (int(t) + 8 * 3600) // dsecs * dsecs - 8 * 3600
+
+
+def hour_begin(t):
+    hsecs = 3600
+    return (int(t) + 8 * 3600) // hsecs * hsecs - 8 * 3600

+ 43 - 0
utils/web.py

@@ -0,0 +1,43 @@
+import traceback
+import numpy as np
+from decorator import decorator
+from fastapi.responses import JSONResponse
+
+
+def json_compatible(data):
+    if isinstance(data,dict):
+        return {k:json_compatible(v) for k,v in data.items()}
+    if isinstance(data,bytes):
+        return str(data)
+    if isinstance(data,np.ndarray):
+        return data.tolist()
+    return data
+
+def web_try(exception_ret=None):
+    @decorator
+    def f(func, *args, **kwargs):
+        error_code = "000"
+        ret = None
+        msg = ''
+        try:
+            ret = func(*args, **kwargs)
+        except Exception as e:
+            msg = traceback.format_exc()
+            if len(e.args) > 0 and isinstance(e.args[0], int):
+                error_code = e.args[0]
+            else:
+                error_code = "101"
+            print('--------------------------------')
+            print ('Get Exception in web try :( \n{}\n'.format(msg))
+            print('--------------------------------')
+            if callable(exception_ret):
+                ret = exception_ret()
+            else:
+                ret = exception_ret
+        finally:
+            if ret is not None and isinstance(ret, JSONResponse):
+                return ret
+            return json_compatible({"status": error_code,
+                                    "result": ret,
+                                    "msg": msg.split('\n')[-2] if msg is not '' else msg})
+    return f