ocr.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import time
  2. from dataclasses import dataclass
  3. from typing import Any
  4. from core.line_parser import LineParser
  5. from core.parser import *
  6. from core.direction import *
  7. import numpy as np
  8. from paddleocr import PaddleOCR
  9. @dataclass
  10. class CetOcr:
  11. ocr: PaddleOCR
  12. # 角度探测器
  13. angle_detector: AngleDetector
  14. # 检测
  15. def predict(self, image: np.ndarray) -> ():
  16. image, angle, result, image_type = self._pre_process(image)
  17. cv2.imwrite('dd.jpg', image)
  18. print(f'---------- detect angle: {angle} 角度 --------')
  19. if angle != 0:
  20. _, _, result = self._ocr(image)
  21. return self._post_process(result, angle, image_type)
  22. def _pre_process(self, image) -> (np.ndarray, int, Any):
  23. # pic角度 result(ocr生)
  24. angle, result, image_type = self.angle_detector.detect_angle(image)
  25. if angle == 1:
  26. image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
  27. if angle == 2:
  28. image = cv2.rotate(image, cv2.ROTATE_180)
  29. if angle == 3:
  30. image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
  31. return image, angle, result, image_type
  32. def _ocr(self, image):
  33. result = self.ocr.ocr(image, cls=True)
  34. if not result:
  35. raise Exception('无法识别')
  36. confs = [line[1][1] for line in result]
  37. # 将检测到的文字放到一个列表中
  38. txts = [line[1][0] for line in result]
  39. return txts, confs, result
  40. def _post_process(self, result, angle: int, image_type):
  41. filters = [lambda x: x.is_slope, lambda x: x.txt.replace(' ', '').encode('utf-8').isalpha()]
  42. line_parser = LineParser(result, filters)
  43. line_result = line_parser.parse()
  44. print('-------------')
  45. print(line_result)
  46. print('-------------')
  47. conf = line_parser.confidence
  48. if int(image_type) == 0:
  49. parser = CETParser(line_result)
  50. elif int(image_type) == 1:
  51. parser = TEMParser(line_result)
  52. else:
  53. raise Exception('无法识别')
  54. ocr_res = parser.parse()
  55. res = {
  56. "confidence": conf,
  57. "orientation": angle, # 原angle是逆时针,转成顺时针
  58. **ocr_res
  59. }
  60. print(res)
  61. return res