# ocr.py (2.8 KB)
  1. import time
  2. from dataclasses import dataclass
  3. from typing import Any
  4. from blfe_core.line_parser import LineParser
  5. from blfe_core.parser import *
  6. from blfe_core.direction import *
  7. import numpy as np
  8. from paddleocr import PaddleOCR
  9. @dataclass
  10. class BusinessLicenseOcr:
  11. ocr: PaddleOCR
  12. angle_detector: AngleDetector
  13. def predict(self, image: np.ndarray) -> ():
  14. image, angle, result = self._pre_process(image)
  15. print(f'---------- detect angle: {angle} 角度 --------')
  16. if angle != 0:
  17. _, _, result = self._ocr(image)
  18. image_type = self._type(result)
  19. # 去除 市场监督 水印
  20. for i_k, i_v in enumerate(result):
  21. if '市场监督' in i_v[1][0] and len(i_v[1][0]) < 7:
  22. del result[i_k]
  23. break
  24. return self._post_process(result, angle, image_type, image)
  25. def _pre_process(self, image) -> (np.ndarray, int, Any):
  26. angle, result= self.angle_detector.detect_angle(image)
  27. if angle == 1:
  28. image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
  29. if angle == 2:
  30. image = cv2.rotate(image, cv2.ROTATE_180)
  31. if angle == 3:
  32. image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
  33. return image, angle, result
  34. def _type(self, result):
  35. anchor = False
  36. code = False
  37. for res in result:
  38. txt = res[1][0]
  39. if "营业执照" in txt:
  40. anchor = res
  41. if "统一社" in txt or "会信用" in txt or "用代码" in txt:
  42. code = res
  43. return 0 if (code and anchor) and (code[0][0][0] < anchor[0][0][0]) else 1
  44. # 获取模型检测结果
  45. def _ocr(self, image):
  46. result = self.ocr.ocr(image, cls=True)
  47. if not result:
  48. raise Exception('无法识别')
  49. confs = [line[1][1] for line in result]
  50. txts = [line[1][0] for line in result]
  51. return txts, confs, result
  52. def _post_process(self, result, angle: int, image_type, image:np.ndarray):
  53. filters = [lambda x: x.is_slope, lambda x: x.txt.replace(' ', '').encode('utf-8').isalpha()]
  54. line_parser = LineParser(result, filters)
  55. line_result = line_parser.parse()
  56. print('-------------')
  57. print(line_result)
  58. print('-------------')
  59. conf = line_parser.confidence
  60. if image_type == 0:
  61. parser = BusinessLicenseParser0(line_result, image, result, self.ocr)
  62. if image_type == 1:
  63. parser = BusinessLicenseParser1(line_result, image, result, self.ocr)
  64. ocr_res = parser.parse()
  65. res = {
  66. "confidence": conf,
  67. "orientation": angle, # 原angle是逆时针,转成顺时针
  68. **ocr_res
  69. }
  70. print(res)
  71. return res