# cutimgs.py - crop the labelled boxes out of dataset images.
import os

import cv2
import matplotlib.pyplot as plt
  4. path = "./imgs/"
  5. cut_img_path = "./cut_img/"
  6. d_map = {0: 'code', 1: 'logo', 2: 'style', 3: 'table', 4: 'text', 5: 'title'}
  7. def makedir(dir):
  8. if not os.path.exists(dir):
  9. os.makedirs(dir)
  10. def xywh2lrbt(img_w, img_h, box):
  11. c, x, y, w, h = int(box[0]), float(box[1]), float(box[2]), float(box[3]), float(box[4])
  12. print(c, x, y, w, h)
  13. print("+++++++++++++++++")
  14. x = x * img_w # 中心坐标x
  15. w = w * img_w # box的宽
  16. y = y * img_h # 中心坐标y
  17. h = h * img_h # box的高
  18. print(c, x, y, w, h)
  19. lt_x, lt_y = max(0, x - w / 2), max(0, y - h / 2) # left_top_x, left_top_y
  20. lb_x, lb_y = max(0, x - w / 2), min(img_h, y + h / 2) # left_bottom_x, left_bottom_y
  21. rb_x, rb_y = min(img_w, x + w / 2), min(img_h, y + h / 2) # right_bottom_x, right_bottom_y
  22. rt_x, rt_y = min(img_w, x + w / 2), max(0, y - h / 2) # right_top_x, right_bottom_y
  23. lrbt = [[lt_x, lt_y], [lb_x, lb_y], [rb_x, rb_y], [rt_x, rt_y]]
  24. print(lrbt)
  25. return lrbt, c
  26. img_cls = os.listdir(path) # 获取train,valid,test目录名称
  27. print(img_cls)
  28. # 遍历每个目录,拿到图片和标签
  29. for img_cla in img_cls:
  30. # print(img_cla)
  31. images_path = path + img_cla + "/images/"
  32. labels_path = path + img_cla + "/labels/"
  33. list_images = os.listdir(images_path)
  34. list_labels = os.listdir(labels_path)
  35. for m in range(len(list_images)):
  36. image = list_images[m]
  37. img_name = os.path.splitext(image)[0]
  38. if image != "784_jpg.rf.8a0d3a961866d9bda06a53d9bf2fbc40.jpg":
  39. continue
  40. for n in range(len(list_labels)):
  41. label = list_labels[n]
  42. label_name = os.path.splitext(label)[0]
  43. if img_name != label_name:
  44. continue
  45. else:
  46. # print(image)
  47. # print(label)
  48. # print(img_cla)
  49. # print("==================")
  50. img = cv2.imread(images_path + image)
  51. img_h, img_w, _ = img.shape
  52. arr_label = [] # labels的边界框,存储四个顶点坐标
  53. ori_clas = [] # label的真实类别
  54. with open(labels_path + label, mode="r", encoding="utf-8") as f:
  55. lines = f.readlines()
  56. # some labels are [],-----> image not idea
  57. if not lines:
  58. break
  59. for line in lines:
  60. line = line.split(" ")
  61. lrbt, c = xywh2lrbt(img_w, img_h, line)
  62. arr_label.append(lrbt)
  63. ori_clas.append(c)
  64. num = 0
  65. for i in range(len(arr_label)):
  66. lt_x, lt_y = int(arr_label[i][0][0]), int(arr_label[i][0][1])
  67. rb_x, rb_y = int(arr_label[i][2][0]), int(arr_label[i][2][1])
  68. cls_num = ori_clas[i]
  69. cls = d_map[cls_num]
  70. print(cls)
  71. makedir(cut_img_path + cls)
  72. # 切割真实标签,并保存到本地
  73. # print(lt_y)
  74. # print(rb_y)
  75. # print(lt_x)
  76. # print(rb_x)
  77. if lt_x >= rb_x or lt_y >= rb_y:
  78. continue
  79. cut_ori_img = img[lt_y:rb_y, lt_x:rb_x, :]
  80. # plt.imshow(cut_ori_img[:,:,::-1])
  81. # plt.show()
  82. save_cut_ori_img = cut_img_path + cls + "/" + str(img_cla) + "_" + str(m) + "_" + str(cls) + "_" + str(num) + ".jpg"
  83. # cv2.imwrite(save_cut_ori_img, cut_ori_img)
  84. num += 1
  85. break