ocr.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from dataclasses import dataclass
  2. from core.parser import *
  3. from core.direction import *
  4. import numpy as np
  5. from paddleocr import PaddleOCR
  6. @dataclass
  7. class IdCardOcr:
  8. ocr: PaddleOCR
  9. image: np.ndarray
  10. image_type: str = '0'
  11. def predict(self):
  12. image, angle = self._pre_process()
  13. txts, confs = self._ocr(image)
  14. if int(self.image_type) == 0:
  15. parser = FrontParser(txts, confs)
  16. elif int(self.image_type) == 1:
  17. parser = BackParser(txts, confs)
  18. return self._post_process(angle, parser, self.image_type)
  19. def _pre_process(self) -> (np.ndarray, int):
  20. image = self.image
  21. angle = detect_angle(image)
  22. print(angle) # 逆时针
  23. if angle == 180:
  24. image = cv2.rotate(image, cv2.ROTATE_180)
  25. if angle == 90:
  26. image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
  27. if angle == 270:
  28. image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
  29. return image, angle
  30. def _ocr(self, image):
  31. # 获取模型检测结果
  32. result = self.ocr.ocr(image, cls=True)
  33. print("------------------")
  34. print(result)
  35. if not result:
  36. raise Exception('无法识别')
  37. confs = [line[1][1] for line in result]
  38. # 将检测到的文字放到一个列表中
  39. txts = [line[1][0] for line in result]
  40. print("......................................")
  41. print(txts)
  42. print("......................................")
  43. return txts, confs
  44. def _post_process(self, angle: int, parser: Parser, image_type: str):
  45. content = parser.parse()
  46. conf = parser.confidence
  47. return {
  48. "confidence": conf,
  49. "card_type": image_type,
  50. "orientation": (4 - angle // 90) % 4, # 原angle是逆时针,转成顺时针
  51. "name": content["Name"].to_dict(),
  52. "id": content["IDNumber"].to_dict(),
  53. "ethnicity": content["Nationality"].to_dict(),
  54. "gender": content["Gender"].to_dict(),
  55. "birthday": content["Birth"].to_dict(),
  56. "address_province": content["address_province"].to_dict(),
  57. "address_city": content["address_city"].to_dict(),
  58. "address_region": content["address_region"].to_dict(),
  59. "address_detail": content["address_detail"].to_dict(),
  60. "expire_date": content["expire_date"].to_dict()
  61. }