123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- import time
- from dataclasses import dataclass
- from typing import Any
- from blfe_core.line_parser import LineParser
- from blfe_core.parser import *
- from blfe_core.direction import *
- import numpy as np
- from paddleocr import PaddleOCR
@dataclass
class BusinessLicenseOcr:
    """OCR pipeline for business-license images.

    Detects the image orientation, rotates the image upright, runs OCR,
    classifies the license layout, drops a short "市场监督" watermark line,
    and hands the OCR lines to a layout-specific parser.

    NOTE(review): ``cv2`` is used by ``_pre_process`` but is not imported
    explicitly in this file; it presumably arrives through one of the
    ``blfe_core`` star imports -- confirm, or import it explicitly.
    """

    ocr: PaddleOCR  # OCR engine; also handed to the layout-specific parsers
    angle_detector: AngleDetector  # provides detect_angle(image) -> (angle, result)

    def predict(self, image: np.ndarray) -> dict:
        """Run the full OCR pipeline on one image.

        Args:
            image: Input image array, passed to the angle detector,
                ``cv2.rotate`` and the OCR engine.

        Returns:
            A dict with ``"confidence"`` (from the line parser),
            ``"orientation"`` (the angle code reported by the detector) and
            every key produced by the layout-specific parser.

        Raises:
            Exception: ('无法识别', "unable to recognize") when the detected
                angle is non-zero and re-running OCR on the rotated image
                returns nothing.
        """
        image, angle, result = self._pre_process(image)
        print(f'---------- detect angle: {angle} 角度 --------')
        if angle != 0:
            # The detector's OCR result came from the un-rotated input, so
            # OCR is re-run on the corrected image.
            _, _, result = self._ocr(image)
        image_type = self._type(result)
        self._remove_watermark(result)
        return self._post_process(result, angle, image_type, image)

    @staticmethod
    def _remove_watermark(result) -> None:
        """Remove, in place, the first short OCR line containing '市场监督'.

        Only the first matching line is removed, and only if its text is
        shorter than 7 characters; longer lines mentioning the same phrase
        are kept.
        """
        for index, line in enumerate(result):
            text = line[1][0]
            if '市场监督' in text and len(text) < 7:
                # Safe to delete during enumeration because we stop at once.
                del result[index]
                return

    def _pre_process(self, image) -> tuple:
        """Detect the orientation and rotate the image upright.

        Returns:
            ``(image, angle, result)``: the possibly-rotated image, the
            detector's angle code, and the OCR result the detector returned
            for the original (un-rotated) input.
        """
        angle, result = self.angle_detector.detect_angle(image)
        # Angle codes 1/2/3 are undone with the matching rotation; any other
        # value leaves the image untouched.
        if angle == 1:
            image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif angle == 2:
            image = cv2.rotate(image, cv2.ROTATE_180)
        elif angle == 3:
            image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
        return image, angle, result

    def _type(self, result):
        """Classify the license layout from OCR line positions.

        Each OCR line is indexed as ``(box, (text, score))``; ``box[0][0]``
        is assumed to be the x coordinate of the box's first point (TODO
        confirm against the OCR output format).

        Returns:
            0 when both a "营业执照" title line and a unified-social-credit-
            code line are present and the code line's first-point x is
            smaller than the title's; 1 otherwise. When several lines match,
            the last one wins.
        """
        anchor = False
        code = False
        for res in result:
            txt = res[1][0]
            if "营业执照" in txt:
                anchor = res
            # Partial fragments, since OCR may split the
            # "统一社会信用代码" label across lines.
            if "统一社" in txt or "会信用" in txt or "用代码" in txt:
                code = res
        return 0 if (code and anchor) and (code[0][0][0] < anchor[0][0][0]) else 1

    def _ocr(self, image):
        """Run the OCR model (with angle classification) on ``image``.

        Returns:
            ``(txts, confs, result)``: the recognized texts, their
            confidence scores, and the raw OCR result list.

        Raises:
            Exception: ('无法识别', "unable to recognize") when the OCR
                engine returns an empty/falsy result.
        """
        result = self.ocr.ocr(image, cls=True)
        if not result:
            raise Exception('无法识别')
        confs = [line[1][1] for line in result]
        txts = [line[1][0] for line in result]
        return txts, confs, result

    def _post_process(self, result, angle: int, image_type, image: np.ndarray):
        """Group OCR lines and extract fields with the layout-specific parser.

        Args:
            result: Raw OCR lines (watermark already removed by ``predict``).
            angle: Angle code from the detector, reported as ``orientation``.
            image_type: Layout code from ``_type``; must be 0 or 1.
            image: The (upright) image, forwarded to the field parser.

        Returns:
            ``{"confidence": ..., "orientation": angle, **parsed_fields}``.

        Raises:
            ValueError: if ``image_type`` is not 0 or 1.
        """
        filters = [
            lambda x: x.is_slope,
            # bytes.isalpha() is True only for non-empty, pure ASCII-letter
            # data, so this matches Latin-letter-only lines and never matches
            # Chinese text (multi-byte UTF-8). How LineParser uses a True
            # filter result is defined elsewhere.
            lambda x: x.txt.replace(' ', '').encode('utf-8').isalpha(),
        ]
        line_parser = LineParser(result, filters)
        line_result = line_parser.parse()
        print('-------------')
        print(line_result)
        print('-------------')
        conf = line_parser.confidence
        # Explicit dispatch: previously an unexpected image_type left
        # `parser` unbound and surfaced as an UnboundLocalError.
        if image_type == 0:
            parser_cls = BusinessLicenseParser0
        elif image_type == 1:
            parser_cls = BusinessLicenseParser1
        else:
            raise ValueError(f'unknown image_type: {image_type!r}')
        ocr_res = parser_cls(line_result, image, result, self.ocr).parse()
        res = {
            "confidence": conf,
            # NOTE(review): the detector's angle code is passed through
            # unchanged. An earlier comment claimed a counter-clockwise ->
            # clockwise conversion happens here, but none does -- confirm
            # which convention consumers of "orientation" expect.
            "orientation": angle,
            **ocr_res,
        }
        print(res)
        return res
|