ocr.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import time
  2. from dataclasses import dataclass
  3. from typing import Any
  4. from core.line_parser import LineParser
  5. from core.parser import *
  6. from core.direction import *
  7. import numpy as np
  8. from paddleocr import PaddleOCR
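# Pipeline overview:
# <- input: picture + picture type
# 1. Rotate the picture upright.
# 2. Re-recognize the rotated picture (get OCR result).
# 3. Group the OCR result into lines (get line result).
# 4. Apply per-field logic to the lines (get dict).
# -> output: dict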
# CET / TEM document OCR
  16. @dataclass
  17. class CetOcr:
  18. ocr: PaddleOCR
  19. # 角度探测器
  20. angle_detector: AngleDetector
  21. # 检测
  22. # <- 传入pic pic_type
  23. # -> dict
  24. def predict(self, image: np.ndarray) -> ():
  25. # 旋转后img angle result(生ocr)
  26. image, angle, result, image_type = self._pre_process(image)
  27. print(f'---------- detect angle: {angle} 角度 --------')
  28. if angle != 0:
  29. _, _, result = self._ocr(image)
  30. return self._post_process(result, angle, image_type)
  31. # 预处理(旋转图片)
  32. # <- img(cv2) img_type
  33. # -> 正向的img(旋转后) 源img角度 result(ocr生)
  34. def _pre_process(self, image) -> (np.ndarray, int, Any):
  35. # pic角度 result(ocr生)
  36. angle, result, image_type = self.angle_detector.detect_angle(image)
  37. if angle == 1:
  38. image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
  39. if angle == 2:
  40. image = cv2.rotate(image, cv2.ROTATE_180)
  41. if angle == 3:
  42. image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
  43. return image, angle, result, image_type
  44. # 获取模型检测结果
  45. def _ocr(self, image):
  46. result = self.ocr.ocr(image, cls=True)
  47. # print("------------------")
  48. # print(result)
  49. if not result:
  50. raise Exception('无法识别')
  51. confs = [line[1][1] for line in result]
  52. # 将检测到的文字放到一个列表中
  53. txts = [line[1][0] for line in result]
  54. # print("......................................")
  55. # print(txts)
  56. # print("......................................")
  57. return txts, confs, result
  58. # <- result(正向img_生ocr) angle img_type
  59. # == 对 正向img_res 进行[行处理]
  60. # -> 最后要返回的结果 dict
  61. def _post_process(self, result, angle: int, image_type):
  62. filters = [lambda x: x.is_slope, lambda x: x.txt.replace(' ', '').encode('utf-8').isalpha()]
  63. line_parser = LineParser(result, filters)
  64. line_result = line_parser.parse()
  65. print('-------------')
  66. print(line_result)
  67. print('-------------')
  68. conf = line_parser.confidence
  69. if int(image_type) == 0:
  70. parser = CETParser(line_result)
  71. elif int(image_type) == 1:
  72. parser = TEMParser(line_result)
  73. else:
  74. raise Exception('无法识别')
  75. # 字段逻辑处理后对res(dict)
  76. ocr_res = parser.parse()
  77. res = {
  78. "confidence": conf,
  79. "orientation": angle, # 原angle是逆时针,转成顺时针
  80. **ocr_res
  81. }
  82. print(res)
  83. return res
  84. # def _get_type(self, image) -> int: