# ocr.py (2.5 KB)
import logging
from dataclasses import dataclass

import numpy as np
from paddleocr import PaddleOCR

# The star imports keep their original relative order so that any name
# shadowing between the two modules is unchanged.
from core.parser import *
from core.direction import *
  6. @dataclass
  7. class IdCardOcr:
  8. ocr: PaddleOCR
  9. def predict(self, image: np.ndarray, image_type: str = '0'):
  10. image, angle = self._pre_process(image)
  11. txts, confs = self._ocr(image)
  12. if int(image_type) == 0:
  13. parser = FrontParser(txts, confs)
  14. elif int(image_type) == 1:
  15. parser = BackParser(txts, confs)
  16. else:
  17. raise Exception('无法识别')
  18. return self._post_process(angle, parser, image_type)
  19. def _pre_process(self, image) -> (np.ndarray, int):
  20. angle = detect_angle(image)
  21. print(angle) # 逆时针
  22. if angle == 180:
  23. image = cv2.rotate(image, cv2.ROTATE_180)
  24. if angle == 90:
  25. image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
  26. if angle == 270:
  27. image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
  28. return image, angle
  29. def _ocr(self, image):
  30. # 获取模型检测结果
  31. result = self.ocr.ocr(image, cls=True)
  32. print("------------------")
  33. print(result)
  34. if not result:
  35. raise Exception('无法识别')
  36. confs = [line[1][1] for line in result]
  37. # 将检测到的文字放到一个列表中
  38. txts = [line[1][0] for line in result]
  39. # print("......................................")
  40. # print(txts)
  41. # print("......................................")
  42. return txts, confs
  43. def _post_process(self, angle: int, parser: Parser, image_type: str):
  44. content = parser.parse()
  45. conf = parser.confidence
  46. res = {
  47. "confidence": conf,
  48. "card_type": image_type,
  49. "orientation": (4 - angle // 90) % 4, # 原angle是逆时针,转成顺时针
  50. "name": content["Name"].to_dict(),
  51. "id": content["IDNumber"].to_dict(),
  52. "ethnicity": content["Nationality"].to_dict(),
  53. "gender": content["Gender"].to_dict(),
  54. "birthday": content["Birth"].to_dict(),
  55. "address_province": content["address_province"].to_dict(),
  56. "address_city": content["address_city"].to_dict(),
  57. "address_region": content["address_region"].to_dict(),
  58. "address_detail": content["address_detail"].to_dict(),
  59. "expire_date": content["expire_date"].to_dict()
  60. }
  61. print(res)
  62. return res