Zhang Li 2 years ago
parent
commit
4afce8b544
10 changed files with 227 additions and 232 deletions
  1. 0 65
      core/back.py
  2. 122 52
      core/idcrad.py
  3. 72 0
      core/ocr.py
  4. BIN
      front-270-1.png
  5. BIN
      front-270.png
  6. BIN
      front-90-1.png
  7. BIN
      front-90.png
  8. 17 4
      main.py
  9. 8 27
      main2.py
  10. 8 84
      server.py

+ 0 - 65
core/back.py

@@ -1,65 +0,0 @@
-import re
-import json
-import string
-
-
-class IdCardStraight:
-    """
-    """
-
-    def __init__(self, result):
-        self.result = [
-            i.replace(" ", "").translate(str.maketrans("", "", string.punctuation))
-            for i in result
-        ]
-        self.out = {"Data": {"FrontResult": {}}}
-        self.res = self.out["Data"]["FrontResult"]
-        self.res["data"] = ""
-        self.res["isauthor"] = ""
-
-    # def IS_author(self):
-    #     """
-    #     签发机关
-    #     """
-
-    # def
-
-    def isauthor(self):
-        """
-        签发机关
-        """
-        addString = []
-        # for i in range(len(self.result)):
-        txt = self.result[2]
-        txt = txt.replace("中华人民共和国", "").replace("居民身份证", "").replace("签发机关", "")
-        print(txt)
-        # txt2=txt.split("签发日期")
-        print('--------')
-        print(txt)
-        # addString.insert(0, txt.split("有效期限")[-1])
-
-        addString.append(txt)
-        # self.result= "temp"
-        if len(addString) > 0:
-            self.res["isauthor"] = "".join(addString)
-
-    def data(self):
-        """
-        签发机关
-        """
-        addString = []
-        txt = self.result[3]
-        txt = txt.replace("中华人民共和国", "").replace("居民身份证", "").replace("有效期限", "")
-        addString.append(txt)
-        # self.result= "temp"
-        if len(addString) > 0:
-            self.res["data"] = "".join(addString)
-
-    def run(self):
-        print(self)
-        self.data()
-        # self.national()
-        # self.birth_no()
-        self.isauthor()
-        print(self.out)
-        return json.dumps(self.out)

+ 122 - 52
core/idcrad.py

@@ -1,34 +1,54 @@
 import re
-import json
 import string
-
-
-class IdCardStraight:
+from dataclasses import dataclass
+from collections import defaultdict
+import numpy as np
+import cpca
+
+
+@dataclass
+class RecItem:
+    text: str = ''
+    confidence: float = 0.
+
+    def to_dict(self):
+        return {"text": self.text, "confidence": self.confidence}
+
+
+class Parser(object):
+    def __init__(self, txts, confs):
+        self.result = txts
+        self.confs = confs
+        assert len(self.result) == len(self.confs), 'result and confs do not match'
+        self.res = defaultdict(RecItem)
+        self.res["Name"] = RecItem()
+        self.res["IDNumber"] = RecItem()
+        self.res["Address"] = RecItem()
+        self.res["Gender"] = RecItem()
+        self.res["Nationality"] = RecItem()
+        self.res["Birth"] = RecItem()
+        self.res["expire_date"] = RecItem()
+
+    def parse(self):
+        return self.res
+
+    @property
+    def confidence(self):
+        return 0.
+
+class FrontParser(Parser):
     """
     """
 
-    def __init__(self, result):
+    def __init__(self, txts, confs):
+        Parser.__init__(self, txts, confs)
         self.result = [
             i.replace(" ", "").translate(str.maketrans("", "", string.punctuation))
-            for i in result
+            for i in txts
         ]
-        self.out = {"Data": {"FrontResult": {}}}
-        self.res = self.out["Data"]["FrontResult"]
-        self.res["Name"] = ""
-        self.res["IDNumber"] = ""
-        self.res["Address"] = ""
-        self.res["Gender"] = ""
-        self.res["Nationality"] = ""
-        self.res["year"] = ""
-        # self.res["Isauthority"]=""
-        # self.res["Effdata"]=""
-
-    # def IS_author(self):
-    #     """
-    #     签发机关
-    #     """
-
-    def year(self):
+        assert len(self.result) == len(self.confs), 'result and confs do not match'
+
+    def birth(self):
         addString = []
         for i in range(len(self.result)):
             txt = self.result[i]
@@ -36,16 +56,11 @@ class IdCardStraight:
                 # txt = txt.replace("出生", "")
                 txt = txt.split('生')[-1]
                 addString.append(txt.strip())
-                self.res["year"] = "".join(addString)
-
-    #         break
-    # print(',,,,')
-    # print(self.result)
-    # txt = txt.replace("出生", "")
-    # addString.append(txt)
-    # print(txt)
-    # self.res["year"] = "".join(addString)
-    def birth_no(self):
+                self.res["Birth"].text = "".join(addString)
+                self.res["Birth"].confidence = self.confs[i]
+                break
+
+    def card_no(self):
         """
         身份证号码
         """
@@ -60,35 +75,42 @@ class IdCardStraight:
 
             if len(res) > 0:
                 if len(res[0]) == 18:
-                    self.res["IDNumber"] = res[0].replace("号码", "")
-                    self.res["Gender"] = "男" if int(res[0][16]) % 2 else "女"
+                    self.res["IDNumber"].text = res[0].replace("号码", "")
+                    self.res["IDNumber"].confidence = self.confs[i]
+                    self.res["Gender"].text = "男" if int(res[0][16]) % 2 else "女"
+                    self.res["Gender"].confidence = self.confs[i]
                 break
 
     def full_name(self):
         """
         身份证姓名
         """
-        #  print(self)
         for i in range(len(self.result)):
             txt = self.result[i]
             if ("姓名" or "名" in txt) and len(txt) > 2:
                 res = re.findall("名[\u4e00-\u9fa5]{1,4}", txt)
                 if len(res) > 0:
-                    self.res["Name"] = res[0].split("名")[-1]
+                    self.res["Name"].text = res[0].split("名")[-1]
+                    self.res["Name"].confidence = self.confs[i]
                     self.result[i] = "temp"  # 避免身份证姓名对地址造成干扰
                     break
 
-    def sex(self):
+    def gender(self):
         """
         性别女民族汉
         """
+        if len(self.res["Gender"].text) != 0: return
         for i in range(len(self.result)):
             txt = self.result[i]
             if "男" in txt:
-                self.res["Gender"] = "男"
+                self.res["Gender"].text = "男"
+                self.res["Gender"].confidence = self.confs[i]
+                break
 
-            elif "女" in txt:
-                self.res["Gender"] = "女"
+            if "女" in txt:
+                self.res["Gender"].text = "女"
+                self.res["Gender"].confidence = self.confs[i]
+                break
 
     def national(self):
         # 性别女民族汉
@@ -97,7 +119,8 @@ class IdCardStraight:
             res = re.findall(".*民族[\u4e00-\u9fa5]+", txt)
 
             if len(res) > 0:
-                self.res["Nationality"] = res[0].split("族")[-1]
+                self.res["Nationality"].text = res[0].split("族")[-1]
+                self.res["Nationality"].confidence = self.confs[i]
                 break
 
     def address(self):
@@ -105,6 +128,7 @@ class IdCardStraight:
         身份证地址
         """
         addString = []
+        conf = []
         for i in range(len(self.result)):
             txt = self.result[i]
             txt = txt.replace("号码", "")
@@ -132,13 +156,31 @@ class IdCardStraight:
                     addString.insert(0, txt.split("址")[-1])
                 else:
                     addString.append(txt)
-
+                conf.append(self.confs[i])
                 self.result[i] = "temp"
         # print(addString)
         if len(addString) > 0:
-            self.res["Address"] = "".join(addString)
-        else:
-            self.res["Address"] = ""
+            self.res["Address"].text = "".join(addString)
+            self.res["Address"].confidence = np.mean(conf)
+        print(f'addr: {self.res["Address"]}')
+
+    def split_addr(self):
+        if self.res["Address"].text:
+            conf = self.res["Address"].confidence
+            print('split_addr', self.res["Address"].text)
+            df = cpca.transform([self.res["Address"].text])
+            print(df)
+
+            province = df.iloc[0, 0]
+            city = df.iloc[0, 1]
+            region = df.iloc[0, 2]
+            detail = df.iloc[0, 3]
+            print(f'pronvince: {province}, city: {city}, region: {region}, detail: {detail}')
+            self.res["address_province"] = RecItem(province, conf)
+            self.res["address_city"] = RecItem(city, conf)
+            self.res["address_region"] = RecItem(region, conf)
+            self.res["address_detail"] = RecItem(detail, conf)
+
 
     def predict_name(self):
         """
@@ -163,12 +205,40 @@ class IdCardStraight:
                             self.res["Name"] = result[0]
                             break
 
-    def run(self):
+    @property
+    def confidence(self):
+        return np.mean(self.confs)
+
+    def parse(self):
         self.full_name()
         self.national()
-        self.birth_no()
+        self.card_no()
         self.address()
-        self.predict_name()
-        self.year()
-        print(self.out)
-        return self.out
+        self.split_addr()
+        # self.predict_name()
+        self.birth()
+        self.gender()
+        return self.res
+
+
+
+class BackParser(Parser):
+    def __init__(self, txts, confs):
+        Parser.__init__(self, txts, confs)
+
+
+    def expire_date(self):
+        for txt, conf in zip(self.result, self.confs):
+            print(txt)
+            res = re.findall('\d{4}\.\d{2}\.\d{2}\-\d{4}\.\d{2}\.\d{2}', txt)
+            print(res)
+            if res:
+                self.res["expire_date"] = RecItem(res[0], conf)
+
+    @property
+    def confidence(self):
+        return np.mean(self.confs)
+
+    def parse(self):
+        self.expire_date()
+        return self.res

+ 72 - 0
core/ocr.py

@@ -0,0 +1,72 @@
+from dataclasses import dataclass
+from core.idcrad import *
+from core.direction import *
+import numpy as np
+from paddleocr import PaddleOCR
+
+
+@dataclass
+class IdCardOcr:
+    ocr: PaddleOCR
+    image: np.ndarray
+    image_type: str = '0'
+
+    def predict(self):
+        txts, confs, angle = self._ocr()
+        if int(self.image_type) == 0:
+            parser = FrontParser(txts, confs)
+        elif int(self.image_type) == 1:
+            parser = BackParser(txts, confs)
+
+        return self._post_process(angle, parser, self.image_type)
+
+    def _align_image(self, image):
+        angle = detect_angle(image)
+        print(angle)  # 逆时针
+        if angle == 180:
+            image = cv2.rotate(image, cv2.ROTATE_180)
+        if angle == 90:
+            image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+        if angle == 270:
+            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+        return image, angle
+
+    def _ocr(self):
+        image = self.image
+        image, angle = self._align_image(image)
+        # 获取模型检测结果
+        result = self.ocr.ocr(image, cls=True)
+        print("------------------")
+        print(result)
+        if not result:
+            return None
+        confs = [line[1][1] for line in result]
+
+        # 将检测到的文字放到一个列表中
+        txts = [line[1][0] for line in result]
+        print("......................................")
+        print(txts)
+        print("......................................")
+        return txts, confs, angle
+
+
+    def _post_process(self, angle: int, parser: Parser, image_type: str):
+        content = parser.parse()
+        conf = parser.confidence
+
+        return {
+            "confidence": conf,
+            "card_type": image_type,
+            "orientation": angle // 90,
+            "name": content["Name"].to_dict(),
+            "id": content["IDNumber"].to_dict(),
+            "ethnicity": content["Nationality"].to_dict(),
+            "gender": content["Gender"].to_dict(),
+            "birthday": content["Birth"].to_dict(),
+            "address_province": content["address_province"].to_dict(),
+            "address_city": content["address_city"].to_dict(),
+            "address_region": content["address_region"].to_dict(),
+            "address_detail": content["address_detail"].to_dict(),
+            "expire_date": content["expire_date"].to_dict()
+        }
+

BIN
front-270-1.png


BIN
front-270.png


BIN
front-90-1.png


BIN
front-90.png


+ 17 - 4
main.py

@@ -1,5 +1,6 @@
+import cv2
 from paddleocr import PaddleOCR
-from core.idcrad import IdCardStraight
+from core.idcrad import FrontParser
 from core.direction import *
 
 # 初始化ocr模型和后处理模型
@@ -10,11 +11,20 @@ ocr = PaddleOCR(use_angle_cls=True, rec_model_dir="./idcard_rec_infer/",
                 rec_char_dict_path="./ppocr_keys_v1.txt", lang="ch", use_gpu=False)
 
 # 定义文件路径
-img_path = "front-180.png"
+img_path = "front-270.png"
 image = cv2.imread(img_path)
+# 逆时针
 angle = detect_angle(image)
 print(angle)
 
+if angle == 180:
+    image = cv2.rotate(image, cv2.ROTATE_180)
+if angle == 90:
+    image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+    cv2.imwrite('front-90-1.png', image)
+if angle == 270:
+    image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+    cv2.imwrite('front-270-1.png', image)
 
 
 # 获取模型检测结果
@@ -23,9 +33,12 @@ print("------------------")
 print(result)
 # 将检测到的文字放到一个列表中
 txts = [line[1][0] for line in result]
+confs = [line[1][1] for line in result]
 print("......................................")
 print(txts)
+print(confs)
 print("......................................")
 # 将结果送入到后处理模型中
-postprocessing = IdCardStraight(txts)
-id_result = postprocessing.run()
+postprocessing = FrontParser(txts, confs)
+parse_result = postprocessing.parse()
+print(parse_result)

+ 8 - 27
main2.py

@@ -1,7 +1,7 @@
 import os
 from paddleocr import PaddleOCR
-from core.back import IdCardStraight
-from core.direction  import *
+from core.idcrad import BackParser
+from core.direction import *
 import json
 
 # 初始化ocr模型和后处理模型
@@ -13,7 +13,7 @@ ocr = PaddleOCR(use_angle_cls=True, rec_model_dir="./idcard_rec_infer/",
 
 # 定义文件路径
 img_path = "back.jpeg"
-image=cv2.imread(img_path)
+image = cv2.imread(img_path)
 angle = detect_angle(image)
 print("------------------")
 print(angle)
@@ -21,34 +21,15 @@ print("------------------")
 
 # 获取模型检测结果
 result = ocr.ocr(img_path, cls=True)
-scores = [line[1][1] for line in result]
-sc=[]
-for i in range(2,4):
-    sc.append(scores[i])
-#scores = [line[1][1] for line in result]
-print("------------------")
-scores2=sum(scores)/len(scores)
-print(sc)
-print(scores2)
-print("------------------")
+confs = [line[1][1] for line in result]
+
 # 将检测到的文字放到一个列表中
 txts = [line[1][0] for line in result]
 print("......................................")
 print(txts)
 print("......................................")
 # 将结果送入到后处理模型中
-postprocessing = IdCardStraight(txts)
-id_result = postprocessing.run()
-result=id_result.encode('utf-8').decode('unicode_escape')
-result = json.loads(result)
+parser = BackParser(txts, confs)
+res = parser.parse()
 
-data={}
-data["confidence"]="null"
-data["orientation"]="null"
-data["expiry_data"]={"text":"null","confidence":"null"}
-data["isauthor"]={"text":"null","confidence":"null"}
-data["confidence"]=scores2
-data["orientation"]=angle
-data["expiry_data"]={"text":result["Data"]["FrontResult"]["data"],"confidence":sc[0]}
-data["isauthor"]={"text":result["Data"]["FrontResult"]["isauthor"],"confidence":sc[1]}
-print(data)
+print(res)

+ 8 - 84
server.py

@@ -1,24 +1,16 @@
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.templating import Jinja2Templates
 from pydantic import BaseModel
 from paddleocr import PaddleOCR
-from core.idcrad import IdCardStraight
-from core.direction import *
-from base64 import b64decode
-import numpy as np
 from sx_utils.sximage import *
 from sx_utils.sxtime import sxtimeit
 from sx_utils.sxweb import web_try
-import cpca
-
+from core.ocr import IdCardOcr
 
 
 app = FastAPI()
 
-origins = [
-    "*"
-]
+origins = ["*"]
 
 app.add_middleware(
     CORSMiddleware,
@@ -31,8 +23,10 @@ app.add_middleware(
 
 
 # 初始化ocr模型和后处理模型
-ocr = PaddleOCR(use_angle_cls=True, rec_model_dir="./idcard_rec_infer/",
-                det_model_dir="./idcard_det_infer/", cls_model_dir="idcard_cls_infer",
+ocr = PaddleOCR(use_angle_cls=True,
+                rec_model_dir="./idcard_rec_infer/",
+                det_model_dir="./idcard_det_infer/",
+                cls_model_dir="idcard_cls_infer",
                 rec_algorithm='CRNN',
                 ocr_version='PP-OCRv2',
                 rec_char_dict_path="./ppocr_keys_v1.txt", lang="ch",
@@ -67,78 +61,8 @@ class IdCardInfo(BaseModel):
 @web_try()
 def idcard(request: Request, id_card: IdCardInfo):
     image = base64_to_np(id_card.image)
-    if int(id_card.image_type) == 0:
-        return _front(image, id_card.image_type)
-    elif int(id_card.image_type) == 1:
-       return _back(image, id_card.image_type)
-    else:
-        raise Exception('not implemented yet')
-
-def _back(image, image_type):
-    raise Exception('not implemented yet')
-
-def _front(image, image_type: str):
-    angle = detect_angle(image)
-    print(angle)
-    # 获取模型检测结果
-    result = ocr.ocr(image, cls=True)
-    print("------------------")
-    print(result)
-    if not result:
-        return None
-    scores = [line[1][1] for line in result]
-
-    score = sum(scores) / len(scores)
-    print("------------------")
-    print(scores)
-    print("------------------")
-
-    sc = []
-    for i in range(0, 6):
-        sc.append(scores[i])
-    sc.append(scores[i])
-    sc.append(scores[i])
-    sc.append(scores[i])
-
-    # 将检测到的文字放到一个列表中
-    txts = [line[1][0] for line in result]
-    print("......................................")
-    print(txts)
-    print("......................................")
-    # 将结果送入到后处理模型中
-    postprocessing = IdCardStraight(txts)
-    id_result = postprocessing.run()
-    content = id_result['Data']['FrontResult']
-
-    location_str = []
-    location_str.append(content["Address"])
-    print(location_str)
-    df = cpca.transform(location_str)
-    print(df)
-
-    province = df.iloc[0, 0]
-    city = df.iloc[0, 1]
-    region = df.iloc[0, 2]
-    detail = df.iloc[0, 3]
-    print(f'pronvince: {province}, city: {city}, region: {region}, detail: {detail}')
-    return {
-        "confidence": score,
-        "card_type": image_type,
-        "orientation": angle // 90,
-        "name": {'text': content['Name'], 'confidence': sc[0]},
-        "id": {'text': content['IDNumber'], 'confidence': sc[1]},
-        "ethnicity": {'text': content['Nationality'], 'confidence': sc[2]},
-        "gender": {'text': content['Gender'], 'confidence': sc[3]},
-        "birthday": {'text': content['year'], 'confidence': sc[4]},
-        "address_province": {'text': province, 'confidence': sc[5]},
-        "address_city": {'text': city, 'confidence': sc[6]},
-        "address_region": {'text': region, 'confidence': sc[7]},
-        "address_detail": {'text': detail, 'confidence': sc[8]},
-        "expire_date": {'text': '', 'confidence': 0}
-    }
-
-
-
+    m = IdCardOcr(ocr, image, id_card.image_type)
+    return m.predict()
 
 
 if __name__ == '__main__':