# Crop labelled boxes (normalized "class cx cy w h" records) out of dataset images, grouped by class name.
- import cv2
- import os
- import matplotlib.pyplot as plt
# Dataset root; each sub-directory is expected to hold "images/" and "labels/".
path = "./imgs/"
# Output root for the cropped regions (one sub-folder per class name).
cut_img_path = "./cut_img/"
# Class index (first field of a label record) -> class name.
d_map = dict(enumerate(["code", "logo", "style", "table", "text", "title"]))
def makedir(dir):
    """Create directory *dir*, including missing parents, if it is absent.

    Calling this for a directory that already exists is a no-op.

    Args:
        dir: Path of the directory to create.

    Raises:
        FileExistsError: If *dir* exists but is not a directory.
    """
    # exist_ok=True replaces the separate "exists?" check, which could race
    # with another process creating the directory between check and create.
    os.makedirs(dir, exist_ok=True)
def xywh2lrbt(img_w, img_h, box):
    """Convert a normalized ``class cx cy w h`` record into pixel corner points.

    The centre/size values are scaled by the image dimensions and the
    resulting rectangle is clamped to ``[0, img_w] x [0, img_h]``.
    Intermediate values are printed to stdout.

    Args:
        img_w: Image width in pixels.
        img_h: Image height in pixels.
        box: Indexable record whose first five items are the class id,
            centre x, centre y, width and height (the latter four as
            fractions of the image size).

    Returns:
        A tuple ``(corners, class_id)`` where ``corners`` is
        ``[[left, top], [left, bottom], [right, bottom], [right, top]]``
        and ``class_id`` is an ``int``.
    """
    c = int(box[0])
    x, y, w, h = [float(box[idx]) for idx in range(1, 5)]
    print(c, x, y, w, h)
    print("+++++++++++++++++")

    # Scale the normalized centre and size to pixel units.
    x, w = x * img_w, w * img_w
    y, h = y * img_h, h * img_h
    print(c, x, y, w, h)

    # Rectangle edges, clamped so the box never leaves the image.
    half_w, half_h = w / 2, h / 2
    left = max(0, x - half_w)
    top = max(0, y - half_h)
    right = min(img_w, x + half_w)
    bottom = min(img_h, y + half_h)

    lrbt = [
        [left, top],      # left-top
        [left, bottom],   # left-bottom
        [right, bottom],  # right-bottom
        [right, top],     # right-top
    ]
    print(lrbt)
    return lrbt, c
# Walk every dataset split under `path`, pair each image with the label file
# that shares its stem, and crop each labelled box out of the image.

# NOTE(review): the loop only ever handled this one file, which looks like a
# leftover debugging filter. Set to None to process every image.
ONLY_IMAGE = "784_jpg.rf.8a0d3a961866d9bda06a53d9bf2fbc40.jpg"
# NOTE(review): the cv2.imwrite call was commented out, so crops were computed
# but never written. Set to True to actually save them under `cut_img_path`.
SAVE_CROPS = False

img_cls = os.listdir(path)  # split directory names, e.g. train / valid / test
print(img_cls)

# Visit each split and collect its images and labels.
for img_cla in img_cls:
    images_path = os.path.join(path, img_cla, "images")
    labels_path = os.path.join(path, img_cla, "labels")
    # The dataset root may also hold plain files or unrelated folders; skip
    # any entry that does not have both sub-directories.
    if not (os.path.isdir(images_path) and os.path.isdir(labels_path)):
        continue

    list_images = os.listdir(images_path)
    # Index label files by stem once instead of rescanning the label list for
    # every image. setdefault keeps the first file per stem, matching the
    # first-match behaviour of a linear scan.
    labels_by_stem = {}
    for label_file in os.listdir(labels_path):
        labels_by_stem.setdefault(os.path.splitext(label_file)[0], label_file)

    for m, image in enumerate(list_images):
        if ONLY_IMAGE is not None and image != ONLY_IMAGE:
            continue
        label = labels_by_stem.get(os.path.splitext(image)[0])
        if label is None:
            continue  # image has no label file

        image_file = os.path.join(images_path, image)
        img = cv2.imread(image_file)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            print("skip unreadable image:", image_file)
            continue
        img_h, img_w = img.shape[:2]

        with open(os.path.join(labels_path, label), mode="r", encoding="utf-8") as f:
            # split() tolerates repeated spaces and the trailing newline;
            # blank lines are dropped.
            records = [line.split() for line in f if line.strip()]
        # Some label files are empty: nothing to crop for this image.
        if not records:
            continue

        # Convert all records first so the conversion output precedes the
        # per-box class output.
        boxes = [xywh2lrbt(img_w, img_h, record) for record in records]

        num = 0  # running index of crops taken from this image
        for corners, cls_num in boxes:
            lt_x, lt_y = int(corners[0][0]), int(corners[0][1])
            rb_x, rb_y = int(corners[2][0]), int(corners[2][1])
            cls_name = d_map[cls_num]
            print(cls_name)
            cls_dir = os.path.join(cut_img_path, cls_name)
            makedir(cls_dir)
            # Skip boxes that collapse to zero width or height after clamping
            # and integer truncation.
            if lt_x >= rb_x or lt_y >= rb_y:
                continue
            # Crop the ground-truth region; optionally save it to disk.
            cut_ori_img = img[lt_y:rb_y, lt_x:rb_x, :]
            save_cut_ori_img = os.path.join(
                cls_dir,
                "{}_{}_{}_{}.jpg".format(img_cla, m, cls_name, num),
            )
            if SAVE_CROPS and not cv2.imwrite(save_cut_ori_img, cut_ori_img):
                print("failed to write crop:", save_cut_ori_img)
            num += 1
|