zeke-chin 2 лет назад
Сommit
d96a088213
57 измененных файлов с 1587 добавлено и 0 удалено
  1. BIN
      .DS_Store
  2. 3 0
      .gitignore
  3. 120 0
      Dockerfile
  4. 17 0
      Makefile
  5. 33 0
      README.md
  6. 0 0
      core/__init__.py
  7. 206 0
      core/direction.py
  8. 135 0
      core/line_parser.py
  9. 99 0
      core/ocr.py
  10. 177 0
      core/parser.py
  11. 118 0
      cpu.Dockerfile
  12. 21 0
      docker-compose.yml
  13. 29 0
      environment.yml
  14. BIN
      images/.DS_Store
  15. BIN
      images/8-16/.DS_Store
  16. BIN
      images/8-16/08_img.jpg
  17. BIN
      images/8-16/12_img.jpg
  18. BIN
      images/8-16/14_img.jpg
  19. BIN
      images/8-16/18_img.jpg
  20. BIN
      images/8-16/19_img.jpg
  21. BIN
      images/8-16/20_img.jpg
  22. BIN
      images/8-16/affa.jpg
  23. BIN
      images/8-16/fasdfas.jpg
  24. BIN
      images/8-16/hh.jpg
  25. BIN
      images/8-16/wgw.jpg
  26. BIN
      images/cet/.DS_Store
  27. BIN
      images/cet/01_img.jpg
  28. 10 0
      images/cet/01_img.json
  29. BIN
      images/cet/02_img.jpg
  30. 10 0
      images/cet/02_img.json
  31. BIN
      images/cet/03_img.jpg
  32. 10 0
      images/cet/03_img.json
  33. BIN
      images/cet/04_img.jpg
  34. 10 0
      images/cet/04_img.json
  35. BIN
      images/cet/05_img.jpg
  36. 10 0
      images/cet/05_img.json
  37. BIN
      images/cet/07_img.jpg
  38. 10 0
      images/cet/07_img.json
  39. BIN
      images/cet/08_img.jpg
  40. 10 0
      images/cet/08_img.json
  41. BIN
      images/cet/09_img.jpg
  42. 10 0
      images/cet/09_img.json
  43. BIN
      images/cet/10_img.jpg
  44. 10 0
      images/cet/10_img.json
  45. BIN
      images/tem/00288070A20资格证书001(1).jpg
  46. BIN
      images/tem/10022276A20资格证书003(1).jpg
  47. BIN
      images/tem/10022507A20资格证书001(1).jpg
  48. 11 0
      run.py
  49. 122 0
      server.py
  50. 0 0
      sx_utils/__init__.py
  51. 12 0
      sx_utils/sximage.py
  52. 125 0
      sx_utils/sxtime.py
  53. 43 0
      sx_utils/sxweb.py
  54. BIN
      testing/.DS_Store
  55. 0 0
      testing/__init__.py
  56. 180 0
      testing/true_test.py
  57. 46 0
      testing/utils.py

+ 3 - 0
.gitignore

@@ -0,0 +1,3 @@
+.idea
+convert_*
+generate_test.py

+ 120 - 0
Dockerfile

@@ -0,0 +1,120 @@
+FROM nvidia/cuda:11.0-cudnn8-devel-ubuntu18.04 AS builder
+
+RUN sed -i 's#archive.ubuntu.com#mirrors.aliyun.com#g' /etc/apt/sources.list  \
+    && sed -i 's#security.ubuntu.com#mirrors.aliyun.com#g' /etc/apt/sources.list
+
+ENV LANG=zh_CN.UTF-8 LANGUAGE=zh_CN:zh LC_ALL=zh_CN.UTF-8 DEBIAN_FRONTEND=noninteractive
+
+RUN rm -rf  /etc/apt/sources.list.d/  && apt update
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    supervisor \
+    iputils-ping \
+    wget \
+    zsh \
+    build-essential \
+    cmake \
+    git \
+    curl \
+    vim \
+    ca-certificates \
+    libjpeg-dev \
+    zip \
+    unzip \
+    libpng-dev \
+    openssh-server \
+    autojump \
+    language-pack-zh-hans \
+    ttf-wqy-zenhei \
+    libgl1-mesa-glx  \
+    libglib2.0-0 \
+    locales &&\
+    rm -rf /var/lib/apt/lists/*
+
+
+RUN locale-gen zh_CN.UTF-8
+RUN dpkg-reconfigure locales
+
+
+CMD ["supervisord", "-n"]
+
+FROM builder as builder1
+
+ENV PYTHON_VERSION 3
+RUN chsh -s `which zsh`
+RUN curl -o ~/miniconda.sh -O  https://repo.anaconda.com/miniconda/Miniconda${PYTHON_VERSION}-latest-Linux-x86_64.sh  && \
+    chmod +x ~/miniconda.sh && \
+    ~/miniconda.sh -b -p /opt/conda && \
+    rm ~/miniconda.sh
+
+RUN ln /opt/conda/bin/conda /usr/local/bin/conda
+RUN conda init zsh
+RUN conda install mamba -n base -c conda-forge
+RUN ln /opt/conda/bin/mamba /usr/local/bin/mamba && mamba init zsh
+
+
+
+FROM builder1 as builder2
+
+RUN apt-get update && apt-get install -y --no-install-recommends openssh-server && rm -rf /var/lib/apt/lists/*
+RUN mkdir /var/run/sshd
+RUN echo 'root:root' | chpasswd
+RUN sed -i 's/.*PermitRootLogin .*/PermitRootLogin yes/' /etc/ssh/sshd_config
+# SSH login fix. Otherwise user is kicked off after login
+RUN sed -i 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' /etc/pam.d/sshd
+
+RUN echo "\
+[program:sshd] \n\
+command=/usr/sbin/sshd -D\n\
+autorestart=True\n\
+autostart=True\n\
+redirect_stderr = true\n\
+" > /etc/supervisor/conf.d/sshd.conf
+
+EXPOSE 22
+
+FROM builder2 as builder3
+
+WORKDIR /workspace
+ADD environment.yml /environment.yml
+RUN sed -i 's#- paddlepaddle#- paddlepaddle-gpu==2.3.0.post110#g' /environment.yml && cat /environment.yml
+RUN mamba update -n base -c defaults conda -y && mamba env create -f /environment.yml && rm -rf /root/.cache
+
+# RUN /opt/conda/envs/py38/bin/python -m ipykernel install --name py38 --display-name "py38"
+# RUN echo "c.MultiKernelManager.default_kernel_name = 'py38'">>/root/.jupyter/jupyter_notebook_config.py
+RUN echo "\
+[program:be]\n\
+directory=/workspace\n\
+command=/opt/conda/envs/py38/bin/gunicorn server:app --workers 1 --worker-class=uvicorn.workers.UvicornWorker  --bind 0.0.0.0:8080 --reload \n\
+autorestart=true\n\
+startretries=0\n\
+redirect_stderr=true\n\
+stdout_logfile=/var/log/be.log\n\
+stdout_logfile_maxbytes=50MB\n\
+environment=PYTHONUNBUFFERED=1\n\
+" > /etc/supervisor/conf.d/be.conf
+
+ARG VERSION
+ENV USE_CUDA $VERSION
+Add . /workspace
+EXPOSE 8080
+
+
+
+# RUN mamba install -y jupyterlab -n base && mamba init zsh
+# RUN /opt/conda/bin/jupyter notebook --generate-config && \
+#     echo "c.NotebookApp.password='argon2:\$argon2id\$v=19\$m=10240,t=10,p=8\$+zIUCF9Uk2FiCHlV8njX5A\$I5Mm/64DORArcXYTXWRVng'">>/root/.jupyter/jupyter_notebook_config.py
+
+
+# RUN mkdir -p /data && echo "\
+# [program:jupyter]\n\
+# directory=/data\n\
+# command=/opt/conda/bin/jupyter lab --ip 0.0.0.0 --port 8888 --allow-root --no-browser \n\
+# autorestart=true\n\
+# startretries=0\n\
+# redirect_stderr=true\n\
+# stdout_logfile=/dev/stdout\n\
+# stdout_logfile_maxbytes=0\n\
+# " > /etc/supervisor/conf.d/jupyter.conf
+
+# EXPOSE 8888

+ 17 - 0
Makefile

@@ -0,0 +1,17 @@
+NAME=cet
+VERSION=latest
+BUILD_TIME      := $(shell date "+%F %T")
+COMMIT_SHA1     := $(shell git rev-parse HEAD)
+AUTHOR          := $(shell git show -s --format='%an')
+
+
+.PHONY: all cpu gpu
+
+all: gpu cpu
+gpu:
+	@docker build -t registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):gpu --build-arg VERSION=gpu .
+	@docker push registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):gpu
+
+cpu:
+	@docker build -f cpu.Dockerfile -t registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):cpu --build-arg VERSION=cpu .
+	@docker push registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):cpu

+ 33 - 0
README.md

@@ -0,0 +1,33 @@
+# 英语等级证书识别
+
+基于paddleocr v2, 用于识别英语等级证书。
+
+## 环境
+
+- python >= 3.7
+- paddleocr
+- paddlepaddle
+- cv2
+
+```shell
+conda env create -f environment.yml
+```
+
+## 服务端
+
+```shell
+# port 8080
+python server.py --port 8080
+```
+
+## 单元测试
+
+```shell
+ python -m unittest discover testing '*_test.py' -v
+```
+
+## 镜像打包
+
+```shell
+make all
+```

+ 0 - 0
core/__init__.py


+ 206 - 0
core/direction.py

@@ -0,0 +1,206 @@
+import re
+from dataclasses import dataclass
+from enum import Enum
+from typing import Tuple, List
+
+import cv2
+import numpy as np
+from paddleocr import PaddleOCR
+
+from core.line_parser import LineParser
+
+
+class Direction(Enum):
+    TOP = 0
+    RIGHT = 1
+    BOTTOM = 2
+    LEFT = 3
+
+
+# 父类
+class OcrAnchor(object):
+    # 输入识别anchor的名字, 如身份证号
+    def __init__(self, name: str, d: List[Direction]):
+        self.name = name
+        # anchor位置
+        self.direction = d
+
+        def t_func(anchor, c, is_horizontal):
+            if is_horizontal:
+                return 0 if anchor[1] < c[1] else 2
+            else:
+                return 1 if anchor[0] > c[0] else 3
+
+        def l_func(anchor, c, is_horizontal):
+            if is_horizontal:
+                return 0 if anchor[0] < c[0] else 2
+            else:
+                return 1 if anchor[1] < c[1] else 3
+
+        def b_func(anchor, c, is_horizontal):
+            if is_horizontal:
+                return 0 if anchor[1] > c[1] else 2
+            else:
+                return 1 if anchor[0] < c[0] else 3
+
+        def r_func(anchor, c, is_horizontal):
+            if is_horizontal:
+                return 0 if anchor[0] > c[0] else 2
+            else:
+                return 1 if anchor[1] > c[1] else 3
+
+        self.direction_funcs = {
+            Direction.TOP: t_func,
+            Direction.BOTTOM: b_func,
+            Direction.LEFT: l_func,
+            Direction.RIGHT: r_func,
+        }
+
+    # 获取中心区域坐标 -> (x, y)
+    def get_rec_area(self, res) -> Tuple[float, float]:
+        """获得整张身份证的识别区域, 返回识别区域的中心点"""
+        boxes = []
+        for row in res:
+            for r in row:
+                boxes.extend(r.box)
+        boxes = np.stack(boxes)
+        l, t = np.min(boxes, 0)
+        r, b = np.max(boxes, 0)
+        # 识别区域的box
+        # big_box = [[l, t], [r, t], [r, b], [l, b]]
+        # w, h = (r - l, b - t)
+        return (l + r) / 2, (t + b) / 2
+
+    # 判断是否是 锚点
+    def is_anchor(self, txt, box) -> bool:
+        pass
+
+    # 找 锚点 -> 锚点坐标
+    def find_anchor(self, res) -> Tuple[bool, float, float]:
+        """
+        寻找锚点 中心点坐标
+        """
+        for row in res:
+            for r in row:
+                txt = r.txt.replace('-', '').replace(' ', '')
+                box = r.box
+                if self.is_anchor(txt, box):
+                    l, t = np.min(box, 0)
+                    r, b = np.max(box, 0)
+                    return True, (l + r) / 2, (t + b) / 2
+        return False, 0., 0.
+
+    # 定位 锚点 -> 角度
+    # -> 锚点(x, y)  pic(x, y) is_horizontal
+    def locate_anchor(self, res, is_horizontal) -> int:
+        found, id_cx, id_cy = self.find_anchor(res)
+
+        # 如果识别不到身份证号
+        if not found: raise Exception(f'识别不到anchor{self.name}')
+        cx, cy = self.get_rec_area(res)
+        # print(f'id_cx: {id_cx}, id_cy: {id_cy}')
+        # print(f'cx: {cx}, cy: {cy}')
+        pre = None
+        for d in self.direction:
+            f = self.direction_funcs.get(d, None)
+            angle = f((id_cx, id_cy), (cx, cy), is_horizontal)
+            if pre is None:
+                pre = angle
+            else:
+                if angle != pre:
+                    raise Exception('angle is not compatiable')
+        return pre
+
+        # if is_horizontal:
+        #     # 如果是水平的,身份证号的位置在相对识别区域的下方,方向则为0度,否则是180度
+        #     return 0 if id_cy > cy else 2
+        # else:
+        #     # 如果是竖直的,身份证号的相对位置如果在左边,方向为90度,否则270度
+        #     return 1 if id_cx < cx else 3
+
+
+# 子类1 人像面
+class CETAnchor(OcrAnchor):
+    def __init__(self, name: str, d: List[Direction]):
+        super(CETAnchor, self).__init__(name, d)
+
+    def is_anchor(self, txt, box) -> bool:
+        txts = re.findall('全国大学英语', txt)
+        if len(txts) > 0:
+            return True
+        return False
+
+    def locate_anchor(self, res, is_horizontal) -> int:
+        return super(CETAnchor, self).locate_anchor(res, is_horizontal)
+
+
+# 子类2 国徽面
+class TEMAnchor(OcrAnchor):
+    def __init__(self, name: str, d: List[Direction]):
+        super(TEMAnchor, self).__init__(name, d)
+
+    def is_anchor(self, txt, box) -> bool:
+        txts = re.findall('证书编号', txt)
+        if len(txts) > 0:
+            return True
+        return False
+
+    def locate_anchor(self, res, is_horizontal) -> int:
+        return super(TEMAnchor, self).locate_anchor(res, is_horizontal)
+
+
+# 调用以上 🔧工具
+# <- ocr_生数据
+# == ocr_熟数据(行处理后)
+# -> 角度0/1/2/3
+def detect_angle(result, ocr_anchor: OcrAnchor):
+    filters = [lambda x: x.is_slope, lambda x: x.txt.replace(' ', '').encode('utf-8').isalpha()]
+    lp = LineParser(result, filters)
+    res = lp.parse()
+    print('------ angle ocr -------')
+    print(res)
+    print('------ angle ocr -------')
+    is_horizontal = lp.is_horizontal
+    return ocr_anchor.locate_anchor(res, is_horizontal)
+
+
+@dataclass
+class AngleDetector(object):
+    """
+    角度检测器
+    """
+    ocr: PaddleOCR
+
+    # 角度检测器
+    # <- img(cv2格式)  img_type
+    # == result <- img(cv2)
+    # -> angle       result(ocr生)
+    def detect_angle(self, img):
+        # image_type = int(image_type)
+        # result = self.ocr.ocr(img, cls=True)
+
+        image_type, result = self.detect_img(img)
+
+        ocr_anchor = CETAnchor('CET', [Direction.TOP]) if image_type == 0 else TEMAnchor('TEM', [
+            Direction.BOTTOM])
+
+        try:
+            angle = detect_angle(result, ocr_anchor)
+            return angle, result, image_type
+
+        except Exception as e:
+            print(e)
+            # 如果第一次识别不到,旋转90度再识别
+            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+            result = self.ocr.ocr(img, cls=True)
+            angle = detect_angle(result, ocr_anchor)
+            # 旋转90度之后要重新计算角度
+            return (angle - 1 + 4) % 4, result, image_type
+
+    def detect_img(self, img):
+        result = self.ocr.ocr(img, cls=True)
+        for res in result:
+            if "报告单" in res[1][0]:
+                return 0, result
+        raise Exception("不支持专四专八")
+        # return 1, result

+ 135 - 0
core/line_parser.py

@@ -0,0 +1,135 @@
+import numpy as np
+from dataclasses import dataclass
+
+
+# result 对象
+@dataclass
+class OcrResult(object):
+    box: np.ndarray
+    txt: str
+    conf: float
+
+    def __hash__(self):
+        return hash(repr(self))
+
+    def __repr__(self):
+        return f'txt: {self.txt}, box: {self.box.tolist()}, conf: {self.conf}'
+
+    @property
+    def lt(self):
+        l, t = np.min(self.box, 0)
+        return [l, t]
+
+    @property
+    def rb(self):
+        r, b = np.max(self.box, 0)
+        return [r, b]
+
+    @property
+    def wh(self):
+        l, t = self.lt
+        r, b = self.rb
+        return [r - l, b - t]
+
+    @property
+    def area(self):
+        w, h = self.wh
+        return w * h
+
+    @property
+    def is_slope(self):
+        p0 = self.box[0]
+        p1 = self.box[1]
+        if p0[0] == p1[0]:
+            return False
+        slope = abs(1. * (p0[1] - p1[1]) / (p0[0] - p1[0]))
+        return 0.4 < slope < 2.5
+
+    @property
+    def center(self):
+        l, t = self.lt
+        r, b = self.rb
+        return [(r + l) / 2, (b + t) / 2]
+
+    def one_line(self, b, is_horizontal, eps: float = 20.0) -> bool:
+        y_idx = 0 + is_horizontal
+        x_idx = 1 - y_idx
+        if b.lt[x_idx] < self.lt[x_idx] < self.rb[x_idx] < b.rb[x_idx]: return False
+        if self.lt[x_idx] < b.lt[x_idx] < b.rb[x_idx] < self.rb[x_idx]: return False
+        eps = 0.25 * (self.wh[y_idx] + b.wh[y_idx])
+        dist = abs(self.center[y_idx] - b.center[y_idx])
+        return dist < eps
+
+
+# 行处理器
+class LineParser(object):
+    def __init__(self, ocr_raw_result, filters=None):
+        if filters is None:
+            filters = [lambda x: x.is_slope]
+        self.ocr_res = []
+        for re in ocr_raw_result:
+            o = OcrResult(np.array(re[0]), re[1][0], re[1][1])
+            if any([f(o) for f in filters]): continue
+            self.ocr_res.append(o)
+        # for f in filters:
+        #     self.ocr_res = list(filter(f, self.ocr_res))
+        self.ocr_res = sorted(self.ocr_res, key=lambda x: x.area, reverse=True)
+        self.eps = self.avg_height * 0.7
+
+    @property
+    def is_horizontal(self):
+        res = self.ocr_res
+        wh = np.stack([np.abs(np.array(r.lt) - np.array(r.rb)) for r in res])
+        return np.sum(wh[:, 0] > wh[:, 1]) > np.sum(wh[:, 0] < wh[:, 1])
+
+    @property
+    def avg_height(self):
+        idx = self.is_horizontal + 0
+        return np.mean(np.array([r.wh[idx] for r in self.ocr_res]))
+
+    # 整体置信度
+    @property
+    def confidence(self):
+        return np.mean([r.conf for r in self.ocr_res])
+
+    # 处理器函数
+    def parse(self, eps=40.0):
+        # 存返回值
+        res = []
+
+        # 需要 处理的 OcrResult 对象  的长度
+        length = len(self.ocr_res)
+
+        # 如果字段数 小于等于1 就抛出异常
+        if length <= 1:
+            raise Exception('无法识别')
+
+        # 遍历数组 并处理他
+        for i in range(length):
+            # 拿出 OcrResult对象的 第i值 -暂存-
+            res_i = self.ocr_res[i]
+
+            # 这次的 res_i 之前已经在结果集中,就继续下一个
+            if any(map(lambda x: res_i in x, res)): continue
+
+            # set() -> {}
+            # 初始化一个集合 即-输出-
+            res_row = set()
+
+            for j in range(i, length):
+                res_j = self.ocr_res[j]
+                # 这次的 res_i 之前已经在结果集中,就继续下一个
+                if any(map(lambda x: res_j in x, res)): continue
+
+                if res_i.one_line(res_j, self.is_horizontal, self.eps):
+                    # LineParser 对象  不可以直接加入字典
+
+                    res_row.add(res_j)
+            res.append(res_row)
+        idx = self.is_horizontal + 0
+        res = sorted([sorted(list(r), key=lambda x: x.lt[1 - idx]) for r in res], key=lambda x: x[0].lt[idx])
+        for row in res:
+            print('---')
+            print(''.join([r.txt for r in row]))
+        return res
+

+ 99 - 0
core/ocr.py

@@ -0,0 +1,99 @@
+from dataclasses import dataclass
+from typing import Any
+
+from core.line_parser import LineParser
+from core.parser import *
+from core.direction import *
+import numpy as np
+from paddleocr import PaddleOCR
+
+
+# <- 传入pic pic_type
+# 1. 旋转pic  (to 正向)
+# 2. 重写识别pic  (get res)
+# 3. 行处理res  (get res)
+# 4. 对res字段逻辑识别  (get dict)
+# -> dict
+# 身份证OCR
+@dataclass
+class CetOcr:
+    ocr: PaddleOCR
+    # 角度探测器
+    angle_detector: AngleDetector
+
+    # 检测
+    # <- 传入pic pic_type
+    # -> dict
+    def predict(self, image: np.ndarray) -> ():
+
+        # 旋转后img angle result(生ocr)
+        image, angle, result, image_type = self._pre_process(image)
+        print(f'---------- detect angle: {angle} 角度 --------')
+
+        return self._post_process(result, angle, image_type)
+
+    # 预处理(旋转图片)
+    # <- img(cv2) img_type
+    # -> 正向的img(旋转后) 源img角度 result(ocr生)
+    def _pre_process(self, image) -> (np.ndarray, int, Any):
+        # pic角度 result(ocr生)
+        angle, result, image_type = self.angle_detector.detect_angle(image)
+
+        if angle == 1:
+            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+        print(angle)  # 逆时针
+        if angle == 2:
+            image = cv2.rotate(image, cv2.ROTATE_180)
+        if angle == 3:
+            image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+
+        return image, angle, result, image_type
+
+    # 获取模型检测结果
+    def _ocr(self, image):
+        result = self.ocr.ocr(image, cls=True)
+        print("------------------")
+        print(result)
+        if not result:
+            raise Exception('无法识别')
+        confs = [line[1][1] for line in result]
+
+        # 将检测到的文字放到一个列表中
+        txts = [line[1][0] for line in result]
+        # print("......................................")
+        # print(txts)
+        # print("......................................")
+        return txts, confs, result
+
+    # <- result(正向img_生ocr) angle img_type
+    # == 对 正向img_res 进行[行处理]
+    # -> 最后要返回的结果 dict
+    def _post_process(self, result, angle: int, image_type):
+        filters = [lambda x: x.is_slope, lambda x: x.txt.replace(' ', '').encode('utf-8').isalpha()]
+        line_parser = LineParser(result, filters)
+        line_result = line_parser.parse()
+        print('-------------')
+        print(line_result)
+        print('-------------')
+        conf = line_parser.confidence
+
+        if int(image_type) == 0:
+            parser = CETParser(line_result)
+        elif int(image_type) == 1:
+            parser = TEMParser(line_result)
+        else:
+            raise Exception('无法识别')
+
+        # 字段逻辑处理后对res(dict)
+        ocr_res = parser.parse()
+
+        res = {
+            "confidence": conf,
+            "card_type": str(image_type),
+            "orientation": angle,  # 原angle是逆时针,转成顺时针
+            **ocr_res
+        }
+        print(res)
+        return res
+
+    # def _get_type(self, image) -> int:

+ 177 - 0
core/parser.py

@@ -0,0 +1,177 @@
+import re
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import List
+
+import cpca
+import numpy as np
+from zhon.hanzi import punctuation
+
+from core.line_parser import OcrResult
+
+
+@dataclass
+class RecItem:
+    text: str = ''
+    confidence: float = 0.
+
+    def to_dict(self):
+        return {"text": self.text.strip(), "confidence": np.nan_to_num(self.confidence)}
+
+
+class Parser(object):
+    def __init__(self, ocr_results: List[List[OcrResult]]):
+        self.result = ocr_results
+        self.res = defaultdict(RecItem)
+        self.keys = ["name", "id", "language", "level", "exam_time", "score"]
+        for key in self.keys:
+            self.res[key] = RecItem()
+
+        ch = re.compile(u'[\u4e00-\u9fa5+\u0030-\u0039]')
+        for item in self.result:
+            tail = ['', 1.]
+            for k in range(len(item)):
+                item[k].txt = ''.join(re.findall(ch, item[k].txt))
+                tail[0] = tail[0] + item[k].txt
+                tail[1] = tail[1] + item[k].conf
+            tail[1] = (tail[1] - 1.) / len(item)
+            item.append(tail)
+
+        for i in range(len(self.result)):
+            res = self.result[i]
+            txt = res[-1][0]
+            if "口试" in txt:
+                self.result = self.result[:i]
+                break
+
+    def parse(self):
+        return self.res
+
+
+class CETParser(Parser):
+    """
+    姓名
+    """
+
+    def __init__(self, ocr_results: List[List[OcrResult]]):
+        Parser.__init__(self, ocr_results)
+
+    def extract_zhon(self):
+        if res := re.findall('[\u4E00-\u9FA5]+', self):
+            return res[0]
+
+    def name(self):
+        name_val = ''
+        conf = 0.
+        is_name = False
+        for i in range(len(self.result)):
+            res = self.result[i]
+            txt = res[-1][0]
+            conf = res[-1][1]
+            for s in range(len(txt)):
+                if txt[s] == "名" and s < 2 and "姓名" in txt:
+                    is_name = True
+            if is_name:
+                name_val = txt.split("姓名")[-1]
+                break
+
+        if len(name_val) < 5:
+            self.res["name"] = RecItem(name_val, conf)
+        else:
+            point_unicode = ["\u2E31", "\u2218", "\u2219", "\u22C5", "\u25E6", "\u2981",
+                             "\u00B7", "\u0387", "\u05BC", "\u16EB", "\u2022", "\u2027",
+                             "\u2E30", "\uFF0E", "\u30FB", "\uFF65", "\u10101"]
+            for item in point_unicode:
+                point = re.findall(item, name_val)
+                if len(point) != 0:
+                    name_list = name_val.split(point[0])
+                    self.res['name'] = RecItem(name_list[0] + '\u00B7' + name_list[1], conf)
+                    return
+
+    def id(self):
+        """
+        身份证号码
+        """
+        for i in range(len(self.result)):
+            res = self.result[i]
+            txt = res[-1][0]
+            conf = res[-1][1]
+
+            id_num = re.findall("\d{10,18}[X|x|×]*", txt)
+            if id_num and len(id_num[0]) == 18:
+                self.res['id'] = RecItem(id_num[0].replace('x', "X").replace('×', "X"), conf)
+                break
+
+    def language(self):
+        """
+        语言
+        """
+        self.res['language'] = RecItem("英语", 1.)
+
+    def level(self):
+        """
+        等级
+        """
+        for i in range(len(self.result)):
+            res = self.result[i]
+            txt = res[-1][0]
+            conf = res[-1][1]
+
+            if "四级" in txt:
+                self.res['level'] = RecItem("CET4", conf)
+                return
+            elif "六级" in txt:
+                self.res['level'] = RecItem("CET6", conf)
+                return
+        raise Exception("四六级无法识别")
+
+    def exam_time(self):
+        """
+        考试时间
+        """
+        for i in range(len(self.result)):
+            res = self.result[i]
+            txt = res[-1][0]
+            conf = res[-1][1]
+
+            if "时间" in txt:
+                txt = txt.split("时间")[-1]
+                self.res["exam_time"] = RecItem(self.to_data(txt), conf)
+                return
+
+    def score(self):
+        """
+        总分
+        """
+        for i in range(len(self.result)):
+            res = self.result[i]
+            txt = res[-1][0]
+            conf = res[-1][1]
+
+            if "时间" in txt:
+                txt = txt.split("月")[-1][:3]
+                self.res["score"] = RecItem(txt, conf)
+                return
+
+    def to_data(self, txt):
+        date_in = re.findall(r"\d+", txt)
+        return f'{date_in[0][-4:]}年{date_in[1]}月'
+
+    def parse(self):
+        self.name()
+        self.id()
+        self.language()
+        self.level()
+        self.exam_time()
+        self.score()
+        return {key: self.res[key].to_dict() for key in self.keys}
+
+
+class TEMParser(Parser):
+    def __init__(self, ocr_results: List[List[OcrResult]]):
+        Parser.__init__(self, ocr_results)
+
+    def parse(self):
+        # self.expire_date()
+
+        return {key: self.res[key].to_dict() for key in self.keys}

+ 118 - 0
cpu.Dockerfile

@@ -0,0 +1,118 @@
+FROM ubuntu:18.04 AS builder
+
+RUN sed -i 's#archive.ubuntu.com#mirrors.aliyun.com#g' /etc/apt/sources.list  \
+    && sed -i 's#security.ubuntu.com#mirrors.aliyun.com#g' /etc/apt/sources.list
+
+ENV LANG=zh_CN.UTF-8 LANGUAGE=zh_CN:zh LC_ALL=zh_CN.UTF-8 DEBIAN_FRONTEND=noninteractive
+
+RUN rm -rf  /etc/apt/sources.list.d/  && apt update
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    supervisor \
+    iputils-ping \
+    wget \
+    zsh \
+    build-essential \
+    cmake \
+    git \
+    curl \
+    vim \
+    ca-certificates \
+    libjpeg-dev \
+    zip \
+    unzip \
+    libpng-dev \
+    openssh-server \
+    autojump \
+    language-pack-zh-hans \
+    ttf-wqy-zenhei \
+    libgl1-mesa-glx  \
+    libglib2.0-0 \
+    locales &&\
+    rm -rf /var/lib/apt/lists/*
+
+
+RUN locale-gen zh_CN.UTF-8
+RUN dpkg-reconfigure locales
+
+
+CMD ["supervisord", "-n"]
+
+FROM builder as builder1
+
+ENV PYTHON_VERSION 3
+RUN chsh -s `which zsh`
+RUN curl -o ~/miniconda.sh -O  https://repo.anaconda.com/miniconda/Miniconda${PYTHON_VERSION}-latest-Linux-x86_64.sh  && \
+    chmod +x ~/miniconda.sh && \
+    ~/miniconda.sh -b -p /opt/conda && \
+    rm ~/miniconda.sh
+
+RUN ln /opt/conda/bin/conda /usr/local/bin/conda
+RUN conda init zsh
+RUN conda install mamba -n base -c conda-forge
+RUN ln /opt/conda/bin/mamba /usr/local/bin/mamba && mamba init zsh
+
+
+
+FROM builder1 as builder2
+
+RUN apt-get update && apt-get install -y --no-install-recommends openssh-server && rm -rf /var/lib/apt/lists/*
+RUN mkdir /var/run/sshd
+RUN echo 'root:root' | chpasswd
+RUN sed -i 's/.*PermitRootLogin .*/PermitRootLogin yes/' /etc/ssh/sshd_config
+# SSH login fix. Otherwise user is kicked off after login
+RUN sed -i 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' /etc/pam.d/sshd
+
+RUN echo "\
+[program:sshd] \n\
+command=/usr/sbin/sshd -D\n\
+autorestart=True\n\
+autostart=True\n\
+redirect_stderr = true\n\
+" > /etc/supervisor/conf.d/sshd.conf
+
+EXPOSE 22
+
+FROM builder2 as builder3
+
+WORKDIR /workspace
+ADD environment.yml /environment.yml
+RUN mamba update -n base -c defaults conda -y && mamba env create -f /environment.yml && rm -rf /root/.cache
+
+# RUN /opt/conda/envs/py38/bin/python -m ipykernel install --name py38 --display-name "py38"
+# RUN echo "c.MultiKernelManager.default_kernel_name = 'py38'">>/root/.jupyter/jupyter_notebook_config.py
+RUN echo "\
+[program:be]\n\
+directory=/workspace\n\
+command=/opt/conda/envs/py38/bin/gunicorn server:app --workers 1 --worker-class=uvicorn.workers.UvicornWorker  --bind 0.0.0.0:8080 --reload \n\
+autorestart=true\n\
+startretries=0\n\
+redirect_stderr=true\n\
+stdout_logfile=/var/log/be.log\n\
+stdout_logfile_maxbytes=0\n\
+" > /etc/supervisor/conf.d/be.conf
+
+ARG VERSION
+ENV USE_CUDA $VERSION
+Add . /workspace
+EXPOSE 8080
+
+
+
+# RUN mamba install -y jupyterlab -n base && mamba init zsh
+# RUN /opt/conda/bin/jupyter notebook --generate-config && \
+#     echo "c.NotebookApp.password='argon2:\$argon2id\$v=19\$m=10240,t=10,p=8\$+zIUCF9Uk2FiCHlV8njX5A\$I5Mm/64DORArcXYTXWRVng'">>/root/.jupyter/jupyter_notebook_config.py
+
+
+# RUN mkdir -p /data && echo "\
+# [program:jupyter]\n\
+# directory=/data\n\
+# command=/opt/conda/bin/jupyter lab --ip 0.0.0.0 --port 8888 --allow-root --no-browser \n\
+# autorestart=true\n\
+# startretries=0\n\
+# redirect_stderr=true\n\
+# stdout_logfile=/dev/stdout\n\
+# stdout_logfile_maxbytes=0\n\
+# " > /etc/supervisor/conf.d/jupyter.conf
+
+# EXPOSE 8888

+ 21 - 0
docker-compose.yml

@@ -0,0 +1,21 @@
+version: '2'
+services:
+  cet:
+    hostname: cet
+    container_name: cet
+    restart: always
+    image: registry.cn-hangzhou.aliyuncs.com/sxtest/cet:gpu
+    privileged: true
+    ipc: host
+    tty: true
+    working_dir: /workspace
+    ports:
+      - '18050:8080'
+      - '22522:22'
+    volumes:
+      - ./:/workspace
+#    deploy:
+#      resources:
+#        reservations:
+#          devices:
+#            - capabilities: [gpu]

+ 29 - 0
environment.yml

@@ -0,0 +1,29 @@
+name: py38
+channels:
+  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ # Anocanda清华镜像
+  - defaults
+  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
+dependencies:
+  - python=3.8
+  - ipython
+  - pip
+  - pip:
+      - cmake
+      - opencv-python
+      - cython
+      - paddleocr
+      - paddlehub
+      - fastapi
+      - uvicorn
+      - zhon
+      - pytest
+      - jinja2
+      - aiofiles
+      - python-multipart
+      - requests
+      - cpca
+      - gunicorn
+      - -i https://mirror.baidu.com/pypi/simple
+      - paddlepaddle  # gpu==2.3.0.post110
+      - -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
+prefix: /opt/conda/envs/py38

BIN
images/.DS_Store


BIN
images/8-16/.DS_Store


BIN
images/8-16/08_img.jpg


BIN
images/8-16/12_img.jpg


BIN
images/8-16/14_img.jpg


BIN
images/8-16/18_img.jpg


BIN
images/8-16/19_img.jpg


BIN
images/8-16/20_img.jpg


BIN
images/8-16/affa.jpg


BIN
images/8-16/fasdfas.jpg


BIN
images/8-16/hh.jpg


BIN
images/8-16/wgw.jpg


BIN
images/cet/.DS_Store


BIN
images/cet/01_img.jpg


+ 10 - 0
images/cet/01_img.json

@@ -0,0 +1,10 @@
+{
+    "card_type": "0",
+    "orientation": 0,
+    "name": "武文斌",
+    "id": "642226199704273215",
+    "language": "英语",
+    "level": "CET4",
+    "exam_time": "2021年12月",
+    "score": "474"
+}

BIN
images/cet/02_img.jpg


+ 10 - 0
images/cet/02_img.json

@@ -0,0 +1,10 @@
+{
+    "card_type": "0",
+    "orientation": 0,
+    "name": "纪春",
+    "id": "230904199909090529",
+    "language": "英语",
+    "level": "CET4",
+    "exam_time": "2018年6月",
+    "score": "472"
+}

BIN
images/cet/03_img.jpg


+ 10 - 0
images/cet/03_img.json

@@ -0,0 +1,10 @@
+{
+    "card_type": "0",
+    "orientation": 0,
+    "name": "姚帆",
+    "id": "142702199903124229",
+    "language": "英语",
+    "level": "CET4",
+    "exam_time": "2019年12月",
+    "score": "471"
+}

BIN
images/cet/04_img.jpg


+ 10 - 0
images/cet/04_img.json

@@ -0,0 +1,10 @@
+{
+    "card_type": "0",
+    "orientation": 0,
+    "name": "林雄",
+    "id": "431121199807158036",
+    "language": "英语",
+    "level": "CET4",
+    "exam_time": "2017年12月",
+    "score": "490"
+}

BIN
images/cet/05_img.jpg


+ 10 - 0
images/cet/05_img.json

@@ -0,0 +1,10 @@
+{
+    "card_type": "0",
+    "orientation": 0,
+    "name": "李然琦",
+    "id": "370502199709144023",
+    "language": "英语",
+    "level": "CET4",
+    "exam_time": "2016年12月",
+    "score": "438"
+}

BIN
images/cet/07_img.jpg


+ 10 - 0
images/cet/07_img.json

@@ -0,0 +1,10 @@
+{
+    "card_type": "0",
+    "orientation": 0,
+    "name": "吴晓虎",
+    "id": "150426199902113039",
+    "language": "英语",
+    "level": "CET4",
+    "exam_time": "2018年12月",
+    "score": "425"
+}

BIN
images/cet/08_img.jpg


+ 10 - 0
images/cet/08_img.json

@@ -0,0 +1,10 @@
+{
+    "card_type": "0",
+    "orientation": 0,
+    "name": "张鑫",
+    "id": "140227199809282317",
+    "language": "英语",
+    "level": "CET4",
+    "exam_time": "2021年6月",
+    "score": "445"
+}

BIN
images/cet/09_img.jpg


+ 10 - 0
images/cet/09_img.json

@@ -0,0 +1,10 @@
+{
+    "card_type": "0",
+    "orientation": 0,
+    "name": "张鹏远",
+    "id": "150203199812150615",
+    "language": "英语",
+    "level": "CET4",
+    "exam_time": "2020年12月",
+    "score": "437"
+}

BIN
images/cet/10_img.jpg


+ 10 - 0
images/cet/10_img.json

@@ -0,0 +1,10 @@
+{
+    "card_type": "0",
+    "orientation": 0,
+    "name": "袁湘粤",
+    "id": "43052419961021177X",
+    "language": "英语",
+    "level": "CET4",
+    "exam_time": "2020年12月",
+    "score": "448"
+}

BIN
images/tem/00288070A20资格证书001(1).jpg


BIN
images/tem/10022276A20资格证书003(1).jpg


BIN
images/tem/10022507A20资格证书001(1).jpg


+ 11 - 0
run.py

@@ -0,0 +1,11 @@
+if __name__ == '__main__':
+    import uvicorn
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', default='0.0.0.0')
+    parser.add_argument('--port', default=8080)
+    opt = parser.parse_args()
+
+    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)

+ 122 - 0
server.py

@@ -0,0 +1,122 @@
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from paddleocr import PaddleOCR
+
+from core.direction import AngleDetector
+from sx_utils.sximage import *
+from sx_utils.sxtime import sxtimeit
+from sx_utils.sxweb import web_try
+from core.ocr import CetOcr
+import os
+
+# 导入一些包
+
+
+app = FastAPI()
+
+origins = ["*"]
+
+# CORS 跨源资源共享
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+# templates = Jinja2Templates(directory='templates')
+
+use_gpu = False
+if os.getenv('USE_CUDA') == 'gpu':
+    use_gpu = True
+
+print(f'use gpu: {use_gpu}')
+
+# 初始化ocr模型和后处理模型
+# 分类
+# ocr = PaddleOCR(use_angle_cls=True,
+#                 # 方向
+#                 rec_model_dir="./idcard_rec_infer/",
+#                 det_model_dir="./idcard_det_infer/",
+#                 cls_model_dir="idcard_cls_infer",
+#                 # 识别
+#                 rec_algorithm='CRNN',
+#                 ocr_version='PP-OCRv2',
+#                 # 中文字典
+#                 rec_char_dict_path="./ppocr_keys_v1.txt", lang="ch",
+#                 use_gpu=use_gpu,
+#                 # 预训练-->效果不明显
+#                 # 网络不够大、不够深
+#                 # 数据集普遍较小,batch size普遍较小
+#                 warmup=True)
+# ocr = PaddleOCR(use_angle_cls=True,
+#                 use_gpu=use_gpu)
+
+
+ocr = PaddleOCR(use_angle_cls=True,
+                use_gpu=use_gpu,
+                det_db_unclip_ratio=2.5,
+                det_db_thresh=0.1,
+                det_db_box_thresh=0.3,
+                warmup=True)
+#
+# ocr = PaddleOCR(use_angle_cls=True,
+#                 rec_model_dir='./ch_ppocr_server_v2.0_rec_infer',
+#                 det_model_dir='./ch_ppocr_server_v2.0_det_infer',
+#                 cls_model_dir='./idcard_cls_infer',
+#                 ocr_version='PP-OCRv2',
+#                 rec_algorithm='CRNN',
+#                 use_gpu=use_gpu,
+#                 det_db_unclip_ratio=2.5,
+#                 det_db_thresh=0.1,
+#                 det_db_box_thresh=0.3,
+#                 warmup=True)
+
+
+# 初始化 角度检测器 对象
+ad = AngleDetector(ocr)
+
+# 初始化 身份证ocr识别 对象
+m = CetOcr(ocr, ad)
+
+
+# Get 健康检查
+@app.get("/ping")
+def ping():
+    return "pong!"
+
+
+# 解析传入的 json对象
+class CetInfo(BaseModel):
+    image: str
+
+
+# /ocr_system/bankcard 银行卡
+# /ocr_system/regbook 户口本
+# /ocr_system/schoolcert 学信网
+
+# Post 接口
+# 计算耗时
+# 异常处理
+@app.post("/ocr_system/cet")
+@sxtimeit
+@web_try()
+# 传入=> base64码 -> np
+# 返回=> 检测到到结果 -> (conf, angle, parser, image_type)
+def cet(request: Request, cer: CetInfo):
+    image = base64_to_np(cer.image)
+    return m.predict(image)
+
+
+if __name__ == '__main__':
+    import uvicorn
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', default='0.0.0.0')
+    parser.add_argument('--port', default=8080)
+    opt = parser.parse_args()
+
+    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
+    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)

+ 0 - 0
sx_utils/__init__.py


+ 12 - 0
sx_utils/sximage.py

@@ -0,0 +1,12 @@
+from base64 import b64decode
+import numpy as np
+import cv2
+
+
+def base64_to_np(img_data):
+    color_image_flag = 1
+    img_data = img_data.split(',', 1)[-1]
+    # b64decode -> base64图片解码
+    # np.fromstring -> 从字符串中的文本数据初始化的新一维数组
+    # numpy.fromstring(string, dtype=float, count=-1, sep='')
+    return cv2.imdecode(np.fromstring(b64decode(img_data), dtype=np.uint8), color_image_flag)

+ 125 - 0
sx_utils/sxtime.py

@@ -0,0 +1,125 @@
+# coding=utf-8
+# Powered by SoaringNova Technology Company
+import errno
+import os
+import signal
+import time
+from collections import defaultdict
+from functools import wraps
+
+timer_counts = defaultdict(int)
+
+
+class SXTimeoutError(Exception):
+    pass
+
+
+def sxtimeout(seconds=10, error_message=os.strerror(errno.ETIME)):
+    def decorator(func):
+        def _handle_timeout(signum, frame):
+            raise SXTimeoutError(error_message)
+
+        def wrapper(*args, **kwargs):
+            signal.signal(signal.SIGALRM, _handle_timeout)
+            signal.alarm(seconds)
+            try:
+                result = func(*args, **kwargs)
+            finally:
+                signal.alarm(0)
+            return result
+
+        return wraps(func)(wrapper)
+
+    return decorator
+
+
+class SXTIMELIMIT:
+    def __init__(self, limit_time=0):
+        self.st = None
+        self.et = None
+        self.limit_time = limit_time
+
+    def __enter__(self):
+        self.st = time.time()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.et = time.time()
+        dt = self.limit_time - (self.et - self.st) * 1000
+        if dt > 0: time.sleep(float(dt) / 1000)
+
+
+class SXTIMER:
+    total_time = {}  # type: dict
+
+    def __init__(self, tag='', enable_total=False, threshold_ms=0):
+        self.st = None
+        self.et = None
+        self.tag = tag
+        # self.tag = tag if not hasattr(g,'request_id') else '{} {}'.format(getattr(g,'request_id'),tag)
+
+        self.thr = threshold_ms
+        self.enable_total = enable_total
+        if self.enable_total:
+            if self.tag not in self.total_time.keys():
+                self.total_time[self.tag] = []
+
+    def __enter__(self):
+        self.st = time.time()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.et = time.time()
+        dt = (self.et - self.st) * 1000
+        if self.enable_total:
+            self.total_time[self.tag].append(dt)
+
+        if dt > self.thr:
+            print("{}: {}s".format(self.tag, round(dt / 1000, 4)))
+
+    @staticmethod
+    def output():
+        for k, v in SXTIMER.total_time.items():
+            print('{} : {}s, avg{}s'.format(k, round(sum(v) / 1000, 2), round(sum(v) / len(v) / 1000, 2)))
+
+
+def sxtimeit(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        st = time.time()
+        ret = func(*args, **kwargs)
+        dt = time.time() - st
+        endpoint = '{}.{}'.format(func.__module__, func.__name__)
+        timer_counts[endpoint] += 1
+        print('{}[{}] finished, exec {}s'.format(endpoint, '%05d' % timer_counts[endpoint], round(dt, 4)))
+        return ret
+
+    return wrapper  # 返回
+
+
+# def sxtimeit(func):
+#     @wraps(func)
+#     def wrapper(*args, **kwargs):
+#         endpoint = '{}.{}'.format(func.__module__, func.__name__)
+#         setattr(g,'request_id','{}[{}]'.format(endpoint,'%05d' % timer_counts[endpoint]))
+#         timer_counts[endpoint] += 1
+#         st = time.time()
+#         ret = func(*args, **kwargs)
+#         dt = time.time() - st
+#         print ('{} finished, exec {}s'.format(getattr(g,'request_id'), round(dt, 4)))
+#         return ret
+#
+#     return wrapper  # 返回
+
+def t2date(t):
+    import datetime
+    date = datetime.datetime.fromtimestamp(t)
+    return '{}_{}_{}_{}:{}:{}'.format(date.year, date.month, date.day, date.hour, date.minute, date.second)
+
+
+def day_begin(t):
+    dsecs = 24 * 3600
+    return (int(t) + 8 * 3600) // dsecs * dsecs - 8 * 3600
+
+
+def hour_begin(t):
+    hsecs = 3600
+    return (int(t) + 8 * 3600) // hsecs * hsecs - 8 * 3600

+ 43 - 0
sx_utils/sxweb.py

@@ -0,0 +1,43 @@
+import traceback
+import numpy as np
+from decorator import decorator
+from fastapi.responses import JSONResponse
+
+
+def json_compatible(data):
+    if isinstance(data,dict):
+        return {k:json_compatible(v) for k,v in data.items()}
+    if isinstance(data,bytes):
+        return str(data)
+    if isinstance(data,np.ndarray):
+        return data.tolist()
+    return data
+
+def web_try(exception_ret=None):
+    @decorator
+    def f(func, *args, **kwargs):
+        error_code = "000"
+        ret = None
+        msg = ''
+        try:
+            ret = func(*args, **kwargs)
+        except Exception as e:
+            msg = traceback.format_exc()
+            if len(e.args) > 0 and isinstance(e.args[0], int):
+                error_code = e.args[0]
+            else:
+                error_code = "101"
+            print('--------------------------------')
+            print ('Get Exception in web try :( \n{}\n'.format(msg))
+            print('--------------------------------')
+            if callable(exception_ret):
+                ret = exception_ret()
+            else:
+                ret = exception_ret
+        finally:
+            if ret is not None and isinstance(ret, JSONResponse):
+                return ret
+            return json_compatible({"status": error_code,
+                                    "result": ret,
+                                    "msg": msg.split('\n')[-2] if msg is not '' else msg})
+    return f

BIN
testing/.DS_Store


+ 0 - 0
testing/__init__.py


+ 180 - 0
testing/true_test.py

@@ -0,0 +1,180 @@
+import unittest
+from dataclasses import dataclass
+from pathlib import Path
+
+from testing.utils import *
+
+
+@dataclass
+class ResultItem:
+    status: str
+    orientation: int
+    name: str
+    id: str
+    ethnicity: str
+    gender: str
+    birthday: str
+    address: str
+
+
+class TestIdCardAddress(unittest.TestCase):
+    def _helper(self, image_path, item: ResultItem):
+        root = Path(__file__).parent
+        image_path = str(root / image_path)
+        r = send_request(image_path, '0')
+        self.assertEqual(item.status, r['status'], f'{image_path} status case error')
+        self.assertEqual(item.orientation, r['result']['orientation'], f'{image_path} orientation case error')
+        self.assertEqual(item.name, r['result']['name']['text'], f'{image_path} name case error')
+        self.assertEqual(item.id, r['result']['id']['text'], f'{image_path} id case error')
+        self.assertEqual(item.ethnicity, r['result']['ethnicity']['text'], f'{image_path} ethnicity case error')
+        self.assertEqual(item.gender, r['result']['gender']['text'], f'{image_path} gender case error')
+        self.assertEqual(item.birthday, r['result']['birthday']['text'], f'{image_path} birthday case error')
+        self.assertEqual(item.address, r['result']['address']['text'], f'{image_path} address case error')
+
+    def test_01(self):
+        image_path = '../images/ture/01.jpg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=0,
+                                            name='宋宝磊',
+                                            id='150430199905051616',
+                                            ethnicity='汉',
+                                            gender='男',
+                                            birthday='1999年05月05日',
+                                            address='内蒙古赤峰市敖汉旗四家子镇林家地村唐坊沟七组'))
+
+    def test_02(self):
+        image_path = '../images/ture/02.jpg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=2,
+                                            name='方势文',
+                                            id='360428199610220096',
+                                            ethnicity='汉',
+                                            gender='男',
+                                            birthday='1996年10月22日',
+                                            address='江西省九江市都昌县都昌镇沿湖路238号'))
+
+    def test_03(self):
+        image_path = '../images/ture/03.jpg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=0,
+                                            name='彭贤端',
+                                            id='450922199412083669',
+                                            ethnicity='汉',
+                                            gender='女',
+                                            birthday='1994年12月08日',
+                                            address='广西陆川县清湖镇塘寨村新屋队62号'))
+
+    def test_04(self):
+        image_path = '../images/ture/04.png'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=0,
+                                            name='王晓凤',
+                                            id='150921199910021527',
+                                            ethnicity='汉',
+                                            gender='女',
+                                            birthday='1999年10月02日',
+                                            address='呼和浩特市新城区赛马场北路城市维也纳13号楼2单元1303号'))
+
+    def test_05(self):
+        image_path = '../images/ture/05.jpg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=0,
+                                            name='任学东',
+                                            id='152630198501117517',
+                                            ethnicity='汉',
+                                            gender='男',
+                                            birthday='1985年01月11日',
+                                            address='内蒙古乌兰察布市察哈尔右翼前旗巴音塔拉镇谢家村22户'))
+
+    def test_06(self):
+        image_path = '../images/ture/06.jpg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=0,
+                                            name='田浩',
+                                            id='640221199702060618',
+                                            ethnicity='汉',
+                                            gender='男',
+                                            birthday='1997年02月06日',
+                                            address='宁夏平罗县黄渠桥镇侯家梁村二队29'))
+
+    def test_07_0(self):
+        image_path = '../images/ture/07_0.jpg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=0,
+                                            name='左翔宇',
+                                            id='220204199910123017',
+                                            ethnicity='蒙古',
+                                            gender='男',
+                                            birthday='1999年10月12日',
+                                            address='吉林省吉林市船营区鑫安小区2-6-60号'))
+
+    def test_07_90(self):
+        image_path = '../images/ture/07_90.jpg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=1,
+                                            name='左翔宇',
+                                            id='220204199910123017',
+                                            ethnicity='蒙古',
+                                            gender='男',
+                                            birthday='1999年10月12日',
+                                            address='吉林省吉林市船营区鑫安小区2-6-60号'))
+
+    def test_07_180(self):
+        image_path = '../images/ture/07_180.jpg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=2,
+                                            name='左翔宇',
+                                            id='220204199910123017',
+                                            ethnicity='蒙古',
+                                            gender='男',
+                                            birthday='1999年10月12日',
+                                            address='吉林省吉林市船营区鑫安小区2-6-60号'))
+
+    def test_08_0(self):
+        image_path = '../images/ture/08_0.jpg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=0,
+                                            name='张开天',
+                                            id='622301199710247376',
+                                            ethnicity='汉',
+                                            gender='男',
+                                            birthday='1997年10月24日',
+                                            address='甘肃省武威市凉州区九墩乡平乐村四组23号'))
+
+    def test_08_180(self):
+        image_path = '../images/ture/08_180.jpg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=2,
+                                            name='张开天',
+                                            id='622301199710247376',
+                                            ethnicity='汉',
+                                            gender='男',
+                                            birthday='1997年10月24日',
+                                            address='甘肃省武威市凉州区九墩乡平乐村四组23号'))
+
+    def test_09(self):
+        image_path = '../images/ture/09.jpeg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=0,
+                                            name='韩毅',
+                                            id='211102199912162010',
+                                            ethnicity='汉',
+                                            gender='男',
+                                            birthday='1999年12月16日',
+                                            address=
+                                            '辽宁省盘锦市双台子区胜利街道团结社区15区40号'))
+
+    def test_10(self):
+        image_path = '../images/ture/10.jpg'
+        self._helper(image_path, ResultItem(status='000',
+                                            orientation=0,
+                                            name='曾令权',
+                                            id='432524199306221414',
+                                            ethnicity='汉',
+                                            gender='男',
+                                            birthday='1993年06月22日',
+                                            address='湖南省新化县维山乡官庄村陈家冲组21号'))
+
+
+if __name__ == '__main__':
+    unittest.main()

+ 46 - 0
testing/utils.py

@@ -0,0 +1,46 @@
+import base64
+from dataclasses import dataclass
+from typing import Optional
+
+import cv2
+import numpy as np
+import requests
+
+url = 'http://192.168.199.249:18080'
+
+
+def send_request(image_path, image_type):
+    with open(image_path, 'rb') as f:
+        img_str: str = base64.encodebytes(f.read()).decode('utf-8')
+        r = requests.post(f'{url}/ocr_system/idcard', json={'image': img_str, 'image_type': image_type})
+        print(r.json())
+        return r.json()
+
+
+def send_request_cv2(image_path, image_type, rotate=None):
+    img = cv2.imread(image_path)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    if rotate is not None:
+        img = cv2.rotate(img, rotate)
+    _, img = cv2.imencode('.jpg', img)
+    img_str = base64.b64encode(img).decode('utf-8')
+    r = requests.post(f'{url}/ocr_system/idcard', json={'image': img_str, 'image_type': image_type})
+    print(r.json())
+    return r.json()
+
+@dataclass
+class ResultItem:
+    # status: str
+    orientation: int
+    name: str
+    id: str
+    ethnicity: str
+    gender: str
+    birthday: str
+    address: str
+    card_type: Optional[str]
+    address_province: Optional[str]
+    address_city: Optional[str]
+    address_region: Optional[str]
+    address_detail: Optional[str]
+    expire_date: Optional[str]