utility.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import os
import platform
import sys
import time

import cv2
import numpy as np
import paddle
from paddle import inference
from PIL import Image, ImageDraw, ImageFont

from ppocr.utils.logging import get_logger


def str2bool(v):
    return v.lower() in ("true", "t", "1")
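
# Quick sanity check of str2bool (the values shown are illustrative): any
# string other than "true", "t" or "1" (case-insensitive) maps to False.
#   str2bool("True")  -> True
#   str2bool("0")     -> False
#   str2bool("yes")   -> False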


def init_args():
    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--use_xpu", type=str2bool, default=False)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--min_subgraph_size", type=int, default=15)
    parser.add_argument("--precision", type=str, default="fp32")
    parser.add_argument("--gpu_mem", type=int, default=500)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str)
    parser.add_argument("--det_resize_long", type=float, default=960)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")
    parser.add_argument("--vis_seg_map", type=str2bool, default=False)

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
    parser.add_argument("--det_sast_polygon", type=str2bool, default=False)

    # PSE params
    parser.add_argument("--det_pse_thresh", type=float, default=0)
    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
    parser.add_argument("--det_pse_min_area", type=float, default=16)
    parser.add_argument("--det_pse_box_type", type=str, default='quad')
    parser.add_argument("--det_pse_scale", type=int, default=1)

    # FCE params
    parser.add_argument("--scales", type=list, default=[8, 16, 32])
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--fourier_degree", type=int, default=5)
    parser.add_argument("--det_fce_box_type", type=str, default='poly')

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="./ppocr/utils/ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument(
        "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
    parser.add_argument("--e2e_model_dir", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default='max')

    # PGNet params
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    parser.add_argument(
        "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--cpu_threads", type=int, default=10)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)
    parser.add_argument("--warmup", type=str2bool, default=False)

    # params for output and visualization
    parser.add_argument(
        "--draw_img_save_dir", type=str, default="./inference_results")
    parser.add_argument("--save_crop_res", type=str2bool, default=False)
    parser.add_argument("--crop_res_save_dir", type=str, default="./output")

    # multi-process
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    parser.add_argument("--benchmark", type=str2bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")
    parser.add_argument("--show_log", type=str2bool, default=True)
    parser.add_argument("--use_onnx", type=str2bool, default=False)
    return parser


def parse_args():
    parser = init_args()
    return parser.parse_args()
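
# Minimal usage sketch (the script name and flag values below are
# illustrative, not defaults):
#   args = parse_args()
# invoked from the command line as, e.g.:
#   python3 your_predict_script.py --image_dir ./imgs --use_gpu false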


def create_predictor(args, mode, logger):
    if mode == "det":
        model_dir = args.det_model_dir
    elif mode == 'cls':
        model_dir = args.cls_model_dir
    elif mode == 'rec':
        model_dir = args.rec_model_dir
    elif mode == 'table':
        model_dir = args.table_model_dir
    else:
        model_dir = args.e2e_model_dir

    if model_dir is None:
        logger.info("no model dir was found for {}: {}".format(mode, model_dir))
        sys.exit(0)

    if args.use_onnx:
        import onnxruntime as ort
        model_file_path = model_dir
        if not os.path.exists(model_file_path):
            raise ValueError("model file path not found: {}".format(
                model_file_path))
        sess = ort.InferenceSession(model_file_path)
        return sess, sess.get_inputs()[0], None, None
    else:
        model_file_path = model_dir + "/inference.pdmodel"
        params_file_path = model_dir + "/inference.pdiparams"
        if not os.path.exists(model_file_path):
            raise ValueError("model file path not found: {}".format(
                model_file_path))
        if not os.path.exists(params_file_path):
            raise ValueError("params file path not found: {}".format(
                params_file_path))

        config = inference.Config(model_file_path, params_file_path)

        if hasattr(args, 'precision'):
            if args.precision == "fp16" and args.use_tensorrt:
                precision = inference.PrecisionType.Half
            elif args.precision == "int8":
                precision = inference.PrecisionType.Int8
            else:
                precision = inference.PrecisionType.Float32
        else:
            precision = inference.PrecisionType.Float32

        if args.use_gpu:
            gpu_id = get_infer_gpuid()
            if gpu_id is None:
                logger.warning(
                    "GPU is not found in current device by nvidia-smi. "
                    "Please check your device or ignore it if run on jetson.")
            config.enable_use_gpu(args.gpu_mem, 0)
            if args.use_tensorrt:
                config.enable_tensorrt_engine(
                    workspace_size=1 << 30,
                    precision_mode=precision,
                    max_batch_size=args.max_batch_size,
                    min_subgraph_size=args.min_subgraph_size)
                # skip the minimum trt subgraph
                use_dynamic_shape = True
                if mode == "det":
                    min_input_shape = {
                        "x": [1, 3, 50, 50],
                        "conv2d_92.tmp_0": [1, 120, 20, 20],
                        "conv2d_91.tmp_0": [1, 24, 10, 10],
                        "conv2d_59.tmp_0": [1, 96, 20, 20],
                        "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10],
                        "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20],
                        "conv2d_124.tmp_0": [1, 256, 20, 20],
                        "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20],
                        "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20],
                        "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20],
                        "elementwise_add_7": [1, 56, 2, 2],
                        "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2]
                    }
                    max_input_shape = {
                        "x": [1, 3, 1536, 1536],
                        "conv2d_92.tmp_0": [1, 120, 400, 400],
                        "conv2d_91.tmp_0": [1, 24, 200, 200],
                        "conv2d_59.tmp_0": [1, 96, 400, 400],
                        "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200],
                        "conv2d_124.tmp_0": [1, 256, 400, 400],
                        "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400],
                        "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400],
                        "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400],
                        "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400],
                        "elementwise_add_7": [1, 56, 400, 400],
                        "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400]
                    }
                    opt_input_shape = {
                        "x": [1, 3, 640, 640],
                        "conv2d_92.tmp_0": [1, 120, 160, 160],
                        "conv2d_91.tmp_0": [1, 24, 80, 80],
                        "conv2d_59.tmp_0": [1, 96, 160, 160],
                        "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80],
                        "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160],
                        "conv2d_124.tmp_0": [1, 256, 160, 160],
                        "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160],
                        "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160],
                        "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160],
                        "elementwise_add_7": [1, 56, 40, 40],
                        "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40]
                    }
                    min_pact_shape = {
                        "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20],
                        "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20],
                        "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20],
                        "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20]
                    }
                    max_pact_shape = {
                        "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400],
                        "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400],
                        "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400],
                        "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400]
                    }
                    opt_pact_shape = {
                        "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160],
                        "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160],
                        "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160],
                        "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160]
                    }
                    min_input_shape.update(min_pact_shape)
                    max_input_shape.update(max_pact_shape)
                    opt_input_shape.update(opt_pact_shape)
                elif mode == "rec":
                    if args.rec_algorithm not in ["CRNN", "SVTR_LCNet"]:
                        use_dynamic_shape = False
                    imgH = int(args.rec_image_shape.split(',')[-2])
                    min_input_shape = {"x": [1, 3, imgH, 10]}
                    max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 2304]}
                    opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]}
                    config.exp_disable_tensorrt_ops(["transpose2"])
                elif mode == "cls":
                    min_input_shape = {"x": [1, 3, 48, 10]}
                    max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
                    opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
                else:
                    use_dynamic_shape = False
                if use_dynamic_shape:
                    config.set_trt_dynamic_shape_info(
                        min_input_shape, max_input_shape, opt_input_shape)
        elif args.use_xpu:
            config.enable_xpu(10 * 1024 * 1024)
        else:
            config.disable_gpu()
            if hasattr(args, "cpu_threads"):
                config.set_cpu_math_library_num_threads(args.cpu_threads)
            else:
                # default number of cpu threads is 10
                config.set_cpu_math_library_num_threads(10)
            if args.enable_mkldnn:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if args.precision == "fp16":
                    config.enable_mkldnn_bfloat16()

        # enable memory optim
        config.enable_memory_optim()
        config.disable_glog_info()
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
        config.delete_pass("matmul_transpose_reshape_fuse_pass")
        if mode == 'table':
            config.delete_pass("fc_fuse_pass")  # not supported for table
        config.switch_use_feed_fetch_ops(False)
        config.switch_ir_optim(True)

        # create predictor
        predictor = inference.create_predictor(config)
        input_names = predictor.get_input_names()
        for name in input_names:
            input_tensor = predictor.get_input_handle(name)
        output_tensors = get_output_tensors(args, mode, predictor)
        return predictor, input_tensor, output_tensors, config
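
# Typical call site (a sketch; `args` comes from parse_args() and mode is one
# of "det", "cls", "rec", "table" or "e2e"):
#   logger = get_logger()
#   predictor, input_tensor, output_tensors, config = create_predictor(
#       args, 'det', logger)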


def get_output_tensors(args, mode, predictor):
    output_names = predictor.get_output_names()
    output_tensors = []
    if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
        output_name = 'softmax_0.tmp_0'
        if output_name in output_names:
            return [predictor.get_output_handle(output_name)]
        else:
            for output_name in output_names:
                output_tensor = predictor.get_output_handle(output_name)
                output_tensors.append(output_tensor)
    else:
        for output_name in output_names:
            output_tensor = predictor.get_output_handle(output_name)
            output_tensors.append(output_tensor)
    return output_tensors


def get_infer_gpuid():
    sysstr = platform.system()
    if sysstr == "Windows":
        return 0
    if not paddle.fluid.core.is_compiled_with_rocm():
        cmd = "env | grep CUDA_VISIBLE_DEVICES"
    else:
        cmd = "env | grep HIP_VISIBLE_DEVICES"
    env_cuda = os.popen(cmd).readlines()
    if len(env_cuda) == 0:
        return 0
    else:
        # take the first character of the first listed id
        # (single-digit device ids assumed)
        gpu_id = env_cuda[0].strip().split("=")[1]
        return int(gpu_id[0])


def draw_e2e_res(dt_boxes, strs, img_path):
    src_im = cv2.imread(img_path)
    # `text` rather than `str` to avoid shadowing the builtin
    for box, text in zip(dt_boxes, strs):
        box = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
        cv2.putText(
            src_im,
            text,
            org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1)
    return src_im


def draw_text_det_res(dt_boxes, img_path):
    src_im = cv2.imread(img_path)
    for box in dt_boxes:
        box = np.array(box).astype(np.int32).reshape(-1, 2)
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
    return src_im


def resize_img(img, input_size=600):
    """
    Resize the image so that its longest side becomes input_size.
    """
    img = np.array(img)
    im_shape = img.shape
    im_size_max = np.max(im_shape[0:2])
    im_scale = float(input_size) / float(im_size_max)
    img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
    return img
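
# Worked example (hypothetical sizes): a 1200x800 image with input_size=600
# gives im_scale = 600 / 1200 = 0.5, so the output is 600x400.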


def draw_ocr(image,
             boxes,
             txts=None,
             scores=None,
             drop_score=0.5,
             font_path="./doc/fonts/simfang.ttf"):
    """
    Visualize the results of OCR detection and recognition
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape(N, 4, 2)
        txts(list): the texts
        scores(list): the corresponding scores of txts
        drop_score(float): only boxes with scores greater than drop_score
            will be visualized
        font_path: the path of the font used to draw text
    return(array):
        the visualized img
    """
    if scores is None:
        scores = [1] * len(boxes)
    box_num = len(boxes)
    for i in range(box_num):
        if scores is not None and (scores[i] < drop_score or
                                   math.isnan(scores[i])):
            continue
        box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    if txts is not None:
        img = np.array(resize_img(image, input_size=600))
        txt_img = text_visual(
            txts,
            scores,
            img_h=img.shape[0],
            img_w=600,
            threshold=drop_score,
            font_path=font_path)
        img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
        return img
    return image
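
# Usage sketch (paths and inputs are illustrative):
#   image = Image.open("./doc/imgs/example.jpg").convert('RGB')
#   vis = draw_ocr(image, boxes, txts, scores, drop_score=0.5)
#   cv2.imwrite("./vis.jpg", vis[:, :, ::-1])  # RGB -> BGR for cv2.imwrite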


def draw_ocr_box_txt(image,
                     boxes,
                     txts,
                     scores=None,
                     drop_score=0.5,
                     font_path="./doc/simfang.ttf"):
    h, w = image.height, image.width
    img_left = image.copy()
    img_right = Image.new('RGB', (w, h), (255, 255, 255))

    import random

    random.seed(0)
    draw_left = ImageDraw.Draw(img_left)
    draw_right = ImageDraw.Draw(img_right)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        if scores is not None and scores[idx] < drop_score:
            continue
        color = (random.randint(0, 255), random.randint(0, 255),
                 random.randint(0, 255))
        draw_left.polygon(box, fill=color)
        draw_right.polygon(
            [
                box[0][0], box[0][1], box[1][0], box[1][1], box[2][0],
                box[2][1], box[3][0], box[3][1]
            ],
            outline=color)
        box_height = math.sqrt((box[0][0] - box[3][0])**2 +
                               (box[0][1] - box[3][1])**2)
        box_width = math.sqrt((box[0][0] - box[1][0])**2 +
                              (box[0][1] - box[1][1])**2)
        if box_height > 2 * box_width:
            # tall, narrow box: draw the text vertically, one char per line
            font_size = max(int(box_width * 0.9), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            cur_y = box[0][1]
            for c in txt:
                char_size = font.getsize(c)
                draw_right.text(
                    (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
                cur_y += char_size[1]
        else:
            font_size = max(int(box_height * 0.8), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            draw_right.text(
                [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(img_right, (w, 0, w * 2, h))
    return np.array(img_show)


def str_count(s):
    """
    Count the display width of a string, measured in Chinese characters:
    each Chinese character counts as one unit, while a single English
    character or digit counts as half a unit.
    args:
        s(string): the input string
    return(int):
        the display width in Chinese-character units
    """
    import string
    count_zh = count_pu = 0
    s_len = len(s)
    en_dg_count = 0
    for c in s:
        if c in string.ascii_letters or c.isdigit() or c.isspace():
            en_dg_count += 1
        elif c.isalpha():
            count_zh += 1
        else:
            count_pu += 1
    return s_len - math.ceil(en_dg_count / 2)
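
# Worked example (illustrative): for s = "abc123", len(s) is 6 and all six
# characters are ASCII letters or digits, so the result is
# 6 - math.ceil(6 / 2) = 3, i.e. about three Chinese characters wide.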


def text_visual(texts,
                scores,
                img_h=400,
                img_w=600,
                threshold=0.,
                font_path="./doc/simfang.ttf"):
    """
    Create a new blank img and draw txt on it
    args:
        texts(list): the texts to be drawn
        scores(list|None): the corresponding score of each txt
        img_h(int): the height of the blank img
        img_w(int): the width of the blank img
        threshold(float): only texts with a score above threshold are drawn
        font_path: the path of the font used to draw text
    return(array):
        the visualized img
    """
    if scores is not None:
        assert len(texts) == len(
            scores), "The number of txts and corresponding scores must match"

    def create_blank_img():
        blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
        blank_img[:, img_w - 1:] = 0
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    gap = font_size + 5
    txt_img_list = []
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        index += 1
        if scores[idx] < threshold or math.isnan(scores[idx]):
            index -= 1
            continue
        first_line = True
        while str_count(txt) >= img_w // font_size - 4:
            tmp = txt
            txt = tmp[:img_w // font_size - 4]
            if first_line:
                new_txt = str(index) + ': ' + txt
                first_line = False
            else:
                new_txt = '    ' + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[img_w // font_size - 4:]
            if count >= img_h // gap - 1:
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        if first_line:
            new_txt = str(index) + ': ' + txt + '   ' + '%.3f' % (scores[idx])
        else:
            new_txt = '    ' + txt + '  ' + '%.3f' % (scores[idx])
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether to add a new blank img or not
        if count >= img_h // gap - 1 and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)


def base64_to_cv2(b64str):
    import base64
    data = base64.b64decode(b64str.encode('utf8'))
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data
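
# Round-trip sketch (the file path is illustrative):
#   import base64
#   with open("./test.jpg", "rb") as f:
#       b64str = base64.b64encode(f.read()).decode('utf8')
#   img = base64_to_cv2(b64str)  # BGR ndarray, as if read with cv2.imread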


def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    if scores is None:
        scores = [1] * len(boxes)
    for (box, score) in zip(boxes, scores):
        if score < drop_score:
            continue
        box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    return image


def get_rotate_crop_image(img, points):
    # The triple-quoted block below is legacy axis-aligned cropping code,
    # kept for reference.
    '''
    img_height, img_width = img.shape[0:2]
    left = int(np.min(points[:, 0]))
    right = int(np.max(points[:, 0]))
    top = int(np.min(points[:, 1]))
    bottom = int(np.max(points[:, 1]))
    img_crop = img[top:bottom, left:right, :].copy()
    points[:, 0] = points[:, 0] - left
    points[:, 1] = points[:, 1] - top
    '''
    assert len(points) == 4, "shape of points must be 4*2"
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img
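
# Usage sketch (the coordinates are illustrative): points is a float32
# (4, 2) array ordered clockwise from the top-left corner of the text box.
#   points = np.float32([[10, 10], [110, 12], [108, 42], [8, 40]])
#   crop = get_rotate_crop_image(img, points)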


def check_gpu(use_gpu):
    if use_gpu and not paddle.is_compiled_with_cuda():
        use_gpu = False
    return use_gpu


if __name__ == '__main__':
    pass