utility.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import os
  16. import sys
  17. import platform
  18. import cv2
  19. import numpy as np
  20. import paddle
  21. from PIL import Image, ImageDraw, ImageFont
  22. import math
  23. from paddle import inference
  24. import time
  25. from ppocr.utils.logging import get_logger
  26. def str2bool(v):
  27. return v.lower() in ("true", "t", "1")
  28. def init_args():
  29. parser = argparse.ArgumentParser()
  30. # params for prediction engine
  31. parser.add_argument("--use_gpu", type=str2bool, default=True)
  32. parser.add_argument("--ir_optim", type=str2bool, default=True)
  33. parser.add_argument("--use_tensorrt", type=str2bool, default=False)
  34. parser.add_argument("--min_subgraph_size", type=int, default=15)
  35. parser.add_argument("--precision", type=str, default="fp32")
  36. parser.add_argument("--gpu_mem", type=int, default=500)
  37. # params for text detector
  38. parser.add_argument("--image_dir", type=str)
  39. parser.add_argument("--det_algorithm", type=str, default='DB')
  40. parser.add_argument("--det_model_dir", type=str)
  41. parser.add_argument("--det_resize_long", type=float, default=960)
  42. parser.add_argument("--det_limit_side_len", type=float, default=960)
  43. parser.add_argument("--det_limit_type", type=str, default='max')
  44. # DB parmas
  45. parser.add_argument("--det_db_thresh", type=float, default=0.3)
  46. parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
  47. parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
  48. parser.add_argument("--max_batch_size", type=int, default=10)
  49. parser.add_argument("--use_dilation", type=str2bool, default=False)
  50. parser.add_argument("--det_db_score_mode", type=str, default="fast")
  51. # EAST parmas
  52. parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
  53. parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
  54. parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
  55. # SAST parmas
  56. parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
  57. parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
  58. parser.add_argument("--det_sast_polygon", type=str2bool, default=False)
  59. # PSE parmas
  60. parser.add_argument("--det_pse_thresh", type=float, default=0)
  61. parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
  62. parser.add_argument("--det_pse_min_area", type=float, default=16)
  63. parser.add_argument("--det_pse_box_type", type=str, default='quad')
  64. parser.add_argument("--det_pse_scale", type=int, default=1)
  65. # FCE parmas
  66. parser.add_argument("--scales", type=list, default=[8, 16, 32])
  67. parser.add_argument("--alpha", type=float, default=1.0)
  68. parser.add_argument("--beta", type=float, default=1.0)
  69. parser.add_argument("--fourier_degree", type=int, default=5)
  70. parser.add_argument("--det_fce_box_type", type=str, default='poly')
  71. # params for text recognizer
  72. parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
  73. parser.add_argument("--rec_model_dir", type=str)
  74. parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
  75. parser.add_argument("--rec_batch_num", type=int, default=6)
  76. parser.add_argument("--max_text_length", type=int, default=25)
  77. parser.add_argument(
  78. "--rec_char_dict_path",
  79. type=str,
  80. default="./ppocr/s_utils/ppocr_keys_v1.txt")
  81. parser.add_argument("--use_space_char", type=str2bool, default=True)
  82. parser.add_argument(
  83. "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
  84. parser.add_argument("--drop_score", type=float, default=0.5)
  85. # params for e2e
  86. parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
  87. parser.add_argument("--e2e_model_dir", type=str)
  88. parser.add_argument("--e2e_limit_side_len", type=float, default=768)
  89. parser.add_argument("--e2e_limit_type", type=str, default='max')
  90. # PGNet parmas
  91. parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
  92. parser.add_argument(
  93. "--e2e_char_dict_path", type=str, default="./ppocr/s_utils/ic15_dict.txt")
  94. parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
  95. parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')
  96. # params for text classifier
  97. parser.add_argument("--use_angle_cls", type=str2bool, default=False)
  98. parser.add_argument("--cls_model_dir", type=str)
  99. parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
  100. parser.add_argument("--label_list", type=list, default=['0', '180'])
  101. parser.add_argument("--cls_batch_num", type=int, default=6)
  102. parser.add_argument("--cls_thresh", type=float, default=0.9)
  103. parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
  104. parser.add_argument("--cpu_threads", type=int, default=10)
  105. parser.add_argument("--use_pdserving", type=str2bool, default=False)
  106. parser.add_argument("--warmup", type=str2bool, default=False)
  107. #
  108. parser.add_argument(
  109. "--draw_img_save_dir", type=str, default="./inference_results")
  110. parser.add_argument("--save_crop_res", type=str2bool, default=False)
  111. parser.add_argument("--crop_res_save_dir", type=str, default="./output")
  112. # multi-process
  113. parser.add_argument("--use_mp", type=str2bool, default=False)
  114. parser.add_argument("--total_process_num", type=int, default=1)
  115. parser.add_argument("--process_id", type=int, default=0)
  116. parser.add_argument("--benchmark", type=str2bool, default=False)
  117. parser.add_argument("--save_log_path", type=str, default="./log_output/")
  118. parser.add_argument("--show_log", type=str2bool, default=True)
  119. parser.add_argument("--use_onnx", type=str2bool, default=False)
  120. return parser
  121. def parse_args():
  122. parser = init_args()
  123. return parser.parse_args()
  124. def create_predictor(args, mode, logger):
  125. if mode == "det":
  126. model_dir = args.det_model_dir
  127. elif mode == 'cls':
  128. model_dir = args.cls_model_dir
  129. elif mode == 'rec':
  130. model_dir = args.rec_model_dir
  131. elif mode == 'table':
  132. model_dir = args.table_model_dir
  133. else:
  134. model_dir = args.e2e_model_dir
  135. if model_dir is None:
  136. logger.info("not find {} model file path {}".format(mode, model_dir))
  137. sys.exit(0)
  138. if args.use_onnx:
  139. import onnxruntime as ort
  140. model_file_path = model_dir
  141. if not os.path.exists(model_file_path):
  142. raise ValueError("not find model file path {}".format(
  143. model_file_path))
  144. sess = ort.InferenceSession(model_file_path)
  145. return sess, sess.get_inputs()[0], None, None
  146. else:
  147. model_file_path = model_dir + "/inference.pdmodel"
  148. params_file_path = model_dir + "/inference.pdiparams"
  149. if not os.path.exists(model_file_path):
  150. raise ValueError("not find model file path {}".format(
  151. model_file_path))
  152. if not os.path.exists(params_file_path):
  153. raise ValueError("not find params file path {}".format(
  154. params_file_path))
  155. config = inference.Config(model_file_path, params_file_path)
  156. if hasattr(args, 'precision'):
  157. if args.precision == "fp16" and args.use_tensorrt:
  158. precision = inference.PrecisionType.Half
  159. elif args.precision == "int8":
  160. precision = inference.PrecisionType.Int8
  161. else:
  162. precision = inference.PrecisionType.Float32
  163. else:
  164. precision = inference.PrecisionType.Float32
  165. if args.use_gpu:
  166. gpu_id = get_infer_gpuid()
  167. if gpu_id is None:
  168. logger.warning(
  169. "GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson."
  170. )
  171. config.enable_use_gpu(args.gpu_mem, 0)
  172. if args.use_tensorrt:
  173. config.enable_tensorrt_engine(
  174. workspace_size=1 << 30,
  175. precision_mode=precision,
  176. max_batch_size=args.max_batch_size,
  177. min_subgraph_size=args.min_subgraph_size)
  178. # skip the minmum trt subgraph
  179. use_dynamic_shape = True
  180. if mode == "det":
  181. min_input_shape = {
  182. "x": [1, 3, 50, 50],
  183. "conv2d_92.tmp_0": [1, 120, 20, 20],
  184. "conv2d_91.tmp_0": [1, 24, 10, 10],
  185. "conv2d_59.tmp_0": [1, 96, 20, 20],
  186. "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10],
  187. "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20],
  188. "conv2d_124.tmp_0": [1, 256, 20, 20],
  189. "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20],
  190. "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20],
  191. "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20],
  192. "elementwise_add_7": [1, 56, 2, 2],
  193. "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2]
  194. }
  195. max_input_shape = {
  196. "x": [1, 3, 1536, 1536],
  197. "conv2d_92.tmp_0": [1, 120, 400, 400],
  198. "conv2d_91.tmp_0": [1, 24, 200, 200],
  199. "conv2d_59.tmp_0": [1, 96, 400, 400],
  200. "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200],
  201. "conv2d_124.tmp_0": [1, 256, 400, 400],
  202. "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400],
  203. "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400],
  204. "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400],
  205. "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400],
  206. "elementwise_add_7": [1, 56, 400, 400],
  207. "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400]
  208. }
  209. opt_input_shape = {
  210. "x": [1, 3, 640, 640],
  211. "conv2d_92.tmp_0": [1, 120, 160, 160],
  212. "conv2d_91.tmp_0": [1, 24, 80, 80],
  213. "conv2d_59.tmp_0": [1, 96, 160, 160],
  214. "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80],
  215. "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160],
  216. "conv2d_124.tmp_0": [1, 256, 160, 160],
  217. "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160],
  218. "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160],
  219. "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160],
  220. "elementwise_add_7": [1, 56, 40, 40],
  221. "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40]
  222. }
  223. min_pact_shape = {
  224. "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20],
  225. "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20],
  226. "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20],
  227. "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20]
  228. }
  229. max_pact_shape = {
  230. "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400],
  231. "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400],
  232. "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400],
  233. "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400]
  234. }
  235. opt_pact_shape = {
  236. "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160],
  237. "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160],
  238. "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160],
  239. "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160]
  240. }
  241. min_input_shape.update(min_pact_shape)
  242. max_input_shape.update(max_pact_shape)
  243. opt_input_shape.update(opt_pact_shape)
  244. elif mode == "rec":
  245. if args.rec_algorithm not in ["CRNN", "SVTR_LCNet"]:
  246. use_dynamic_shape = False
  247. imgH = int(args.rec_image_shape.split(',')[-2])
  248. min_input_shape = {"x": [1, 3, imgH, 10]}
  249. max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 2304]}
  250. opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]}
  251. elif mode == "cls":
  252. min_input_shape = {"x": [1, 3, 48, 10]}
  253. max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
  254. opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
  255. else:
  256. use_dynamic_shape = False
  257. if use_dynamic_shape:
  258. config.set_trt_dynamic_shape_info(
  259. min_input_shape, max_input_shape, opt_input_shape)
  260. else:
  261. config.disable_gpu()
  262. if hasattr(args, "cpu_threads"):
  263. config.set_cpu_math_library_num_threads(args.cpu_threads)
  264. else:
  265. # default cpu threads as 10
  266. config.set_cpu_math_library_num_threads(10)
  267. if args.enable_mkldnn:
  268. # cache 10 different shapes for mkldnn to avoid memory leak
  269. config.set_mkldnn_cache_capacity(10)
  270. config.enable_mkldnn()
  271. if args.precision == "fp16":
  272. config.enable_mkldnn_bfloat16()
  273. # enable memory optim
  274. config.enable_memory_optim()
  275. config.disable_glog_info()
  276. config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
  277. config.delete_pass("matmul_transpose_reshape_fuse_pass")
  278. if mode == 'table':
  279. config.delete_pass("fc_fuse_pass") # not supported for table
  280. config.switch_use_feed_fetch_ops(False)
  281. config.switch_ir_optim(True)
  282. # create predictor
  283. predictor = inference.create_predictor(config)
  284. input_names = predictor.get_input_names()
  285. for name in input_names:
  286. input_tensor = predictor.get_input_handle(name)
  287. output_tensors = get_output_tensors(args, mode, predictor)
  288. return predictor, input_tensor, output_tensors, config
  289. def get_output_tensors(args, mode, predictor):
  290. output_names = predictor.get_output_names()
  291. output_tensors = []
  292. if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
  293. output_name = 'softmax_0.tmp_0'
  294. if output_name in output_names:
  295. return [predictor.get_output_handle(output_name)]
  296. else:
  297. for output_name in output_names:
  298. output_tensor = predictor.get_output_handle(output_name)
  299. output_tensors.append(output_tensor)
  300. else:
  301. for output_name in output_names:
  302. output_tensor = predictor.get_output_handle(output_name)
  303. output_tensors.append(output_tensor)
  304. return output_tensors
  305. def get_infer_gpuid():
  306. sysstr = platform.system()
  307. if sysstr == "Windows":
  308. return 0
  309. if not paddle.fluid.core.is_compiled_with_rocm():
  310. cmd = "env | grep CUDA_VISIBLE_DEVICES"
  311. else:
  312. cmd = "env | grep HIP_VISIBLE_DEVICES"
  313. env_cuda = os.popen(cmd).readlines()
  314. if len(env_cuda) == 0:
  315. return 0
  316. else:
  317. gpu_id = env_cuda[0].strip().split("=")[1]
  318. return int(gpu_id[0])
  319. def draw_e2e_res(dt_boxes, strs, img_path):
  320. src_im = cv2.imread(img_path)
  321. for box, str in zip(dt_boxes, strs):
  322. box = box.astype(np.int32).reshape((-1, 1, 2))
  323. cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
  324. cv2.putText(
  325. src_im,
  326. str,
  327. org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
  328. fontFace=cv2.FONT_HERSHEY_COMPLEX,
  329. fontScale=0.7,
  330. color=(0, 255, 0),
  331. thickness=1)
  332. return src_im
  333. def draw_text_det_res(dt_boxes, img_path):
  334. src_im = cv2.imread(img_path)
  335. for box in dt_boxes:
  336. box = np.array(box).astype(np.int32).reshape(-1, 2)
  337. cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
  338. return src_im
  339. def resize_img(img, input_size=600):
  340. """
  341. resize img and limit the longest side of the image to input_size
  342. """
  343. img = np.array(img)
  344. im_shape = img.shape
  345. im_size_max = np.max(im_shape[0:2])
  346. im_scale = float(input_size) / float(im_size_max)
  347. img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
  348. return img
  349. def draw_ocr(image,
  350. boxes,
  351. txts=None,
  352. scores=None,
  353. drop_score=0.5,
  354. font_path="./doc/fonts/simfang.ttf"):
  355. """
  356. Visualize the results of OCR detection and recognition
  357. args:
  358. image(Image|array): RGB image
  359. boxes(list): boxes with shape(N, 4, 2)
  360. txts(list): the texts
  361. scores(list): txxs corresponding scores
  362. drop_score(float): only scores greater than drop_threshold will be visualized
  363. font_path: the path of font which is used to draw text
  364. return(array):
  365. the visualized img
  366. """
  367. if scores is None:
  368. scores = [1] * len(boxes)
  369. box_num = len(boxes)
  370. for i in range(box_num):
  371. if scores is not None and (scores[i] < drop_score or
  372. math.isnan(scores[i])):
  373. continue
  374. box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
  375. image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
  376. if txts is not None:
  377. img = np.array(resize_img(image, input_size=600))
  378. txt_img = text_visual(
  379. txts,
  380. scores,
  381. img_h=img.shape[0],
  382. img_w=600,
  383. threshold=drop_score,
  384. font_path=font_path)
  385. img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
  386. return img
  387. return image
  388. def draw_ocr_box_txt(image,
  389. boxes,
  390. txts,
  391. scores=None,
  392. drop_score=0.5,
  393. font_path="./doc/simfang.ttf"):
  394. h, w = image.height, image.width
  395. img_left = image.copy()
  396. img_right = Image.new('RGB', (w, h), (255, 255, 255))
  397. import random
  398. random.seed(0)
  399. draw_left = ImageDraw.Draw(img_left)
  400. draw_right = ImageDraw.Draw(img_right)
  401. for idx, (box, txt) in enumerate(zip(boxes, txts)):
  402. if scores is not None and scores[idx] < drop_score:
  403. continue
  404. color = (random.randint(0, 255), random.randint(0, 255),
  405. random.randint(0, 255))
  406. draw_left.polygon(box, fill=color)
  407. draw_right.polygon(
  408. [
  409. box[0][0], box[0][1], box[1][0], box[1][1], box[2][0],
  410. box[2][1], box[3][0], box[3][1]
  411. ],
  412. outline=color)
  413. box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][
  414. 1])**2)
  415. box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][
  416. 1])**2)
  417. if box_height > 2 * box_width:
  418. font_size = max(int(box_width * 0.9), 10)
  419. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  420. cur_y = box[0][1]
  421. for c in txt:
  422. char_size = font.getsize(c)
  423. draw_right.text(
  424. (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
  425. cur_y += char_size[1]
  426. else:
  427. font_size = max(int(box_height * 0.8), 10)
  428. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  429. draw_right.text(
  430. [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
  431. img_left = Image.blend(image, img_left, 0.5)
  432. img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
  433. img_show.paste(img_left, (0, 0, w, h))
  434. img_show.paste(img_right, (w, 0, w * 2, h))
  435. return np.array(img_show)
  436. def str_count(s):
  437. """
  438. Count the number of Chinese characters,
  439. a single English character and a single number
  440. equal to half the length of Chinese characters.
  441. args:
  442. s(string): the input of string
  443. return(int):
  444. the number of Chinese characters
  445. """
  446. import string
  447. count_zh = count_pu = 0
  448. s_len = len(s)
  449. en_dg_count = 0
  450. for c in s:
  451. if c in string.ascii_letters or c.isdigit() or c.isspace():
  452. en_dg_count += 1
  453. elif c.isalpha():
  454. count_zh += 1
  455. else:
  456. count_pu += 1
  457. return s_len - math.ceil(en_dg_count / 2)
  458. def text_visual(texts,
  459. scores,
  460. img_h=400,
  461. img_w=600,
  462. threshold=0.,
  463. font_path="./doc/simfang.ttf"):
  464. """
  465. create new blank img and draw txt on it
  466. args:
  467. texts(list): the text will be draw
  468. scores(list|None): corresponding score of each txt
  469. img_h(int): the height of blank img
  470. img_w(int): the width of blank img
  471. font_path: the path of font which is used to draw text
  472. return(array):
  473. """
  474. if scores is not None:
  475. assert len(texts) == len(
  476. scores), "The number of txts and corresponding scores must match"
  477. def create_blank_img():
  478. blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
  479. blank_img[:, img_w - 1:] = 0
  480. blank_img = Image.fromarray(blank_img).convert("RGB")
  481. draw_txt = ImageDraw.Draw(blank_img)
  482. return blank_img, draw_txt
  483. blank_img, draw_txt = create_blank_img()
  484. font_size = 20
  485. txt_color = (0, 0, 0)
  486. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  487. gap = font_size + 5
  488. txt_img_list = []
  489. count, index = 1, 0
  490. for idx, txt in enumerate(texts):
  491. index += 1
  492. if scores[idx] < threshold or math.isnan(scores[idx]):
  493. index -= 1
  494. continue
  495. first_line = True
  496. while str_count(txt) >= img_w // font_size - 4:
  497. tmp = txt
  498. txt = tmp[:img_w // font_size - 4]
  499. if first_line:
  500. new_txt = str(index) + ': ' + txt
  501. first_line = False
  502. else:
  503. new_txt = ' ' + txt
  504. draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
  505. txt = tmp[img_w // font_size - 4:]
  506. if count >= img_h // gap - 1:
  507. txt_img_list.append(np.array(blank_img))
  508. blank_img, draw_txt = create_blank_img()
  509. count = 0
  510. count += 1
  511. if first_line:
  512. new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx])
  513. else:
  514. new_txt = " " + txt + " " + '%.3f' % (scores[idx])
  515. draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
  516. # whether add new blank img or not
  517. if count >= img_h // gap - 1 and idx + 1 < len(texts):
  518. txt_img_list.append(np.array(blank_img))
  519. blank_img, draw_txt = create_blank_img()
  520. count = 0
  521. count += 1
  522. txt_img_list.append(np.array(blank_img))
  523. if len(txt_img_list) == 1:
  524. blank_img = np.array(txt_img_list[0])
  525. else:
  526. blank_img = np.concatenate(txt_img_list, axis=1)
  527. return np.array(blank_img)
  528. def base64_to_cv2(b64str):
  529. import base64
  530. data = base64.b64decode(b64str.encode('utf8'))
  531. data = np.fromstring(data, np.uint8)
  532. data = cv2.imdecode(data, cv2.IMREAD_COLOR)
  533. return data
  534. def draw_boxes(image, boxes, scores=None, drop_score=0.5):
  535. if scores is None:
  536. scores = [1] * len(boxes)
  537. for (box, score) in zip(boxes, scores):
  538. if score < drop_score:
  539. continue
  540. box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
  541. image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
  542. return image
  543. def get_rotate_crop_image(img, points):
  544. '''
  545. img_height, img_width = img.shape[0:2]
  546. left = int(np.min(points[:, 0]))
  547. right = int(np.max(points[:, 0]))
  548. top = int(np.min(points[:, 1]))
  549. bottom = int(np.max(points[:, 1]))
  550. img_crop = img[top:bottom, left:right, :].copy()
  551. points[:, 0] = points[:, 0] - left
  552. points[:, 1] = points[:, 1] - top
  553. '''
  554. assert len(points) == 4, "shape of points must be 4*2"
  555. img_crop_width = int(
  556. max(
  557. np.linalg.norm(points[0] - points[1]),
  558. np.linalg.norm(points[2] - points[3])))
  559. img_crop_height = int(
  560. max(
  561. np.linalg.norm(points[0] - points[3]),
  562. np.linalg.norm(points[1] - points[2])))
  563. pts_std = np.float32([[0, 0], [img_crop_width, 0],
  564. [img_crop_width, img_crop_height],
  565. [0, img_crop_height]])
  566. M = cv2.getPerspectiveTransform(points, pts_std)
  567. dst_img = cv2.warpPerspective(
  568. img,
  569. M, (img_crop_width, img_crop_height),
  570. borderMode=cv2.BORDER_REPLICATE,
  571. flags=cv2.INTER_CUBIC)
  572. dst_img_height, dst_img_width = dst_img.shape[0:2]
  573. if dst_img_height * 1.0 / dst_img_width >= 1.5:
  574. dst_img = np.rot90(dst_img)
  575. return dst_img
  576. def check_gpu(use_gpu):
  577. if use_gpu and not paddle.is_compiled_with_cuda():
  578. use_gpu = False
  579. return use_gpu
  580. if __name__ == '__main__':
  581. pass