Raychar 2 роки тому
батько
коміт
ba9ddf64fc
2 змінених файлів з 292 додано та 0 видалено
  1. 182 0
      create_md.py
  2. 110 0
      cutimgs.py

+ 182 - 0
create_md.py

@@ -0,0 +1,182 @@
+"""
+function: 生成md文件
+"""
+from mdutils.mdutils import MdUtils
+import os
+import cv2
+import numpy as np
+import requests
+
+d_map = {0: 'code', 1: 'logo', 2: 'style', 3: 'table', 4: 'text', 5: 'title'}
+images_path = "./train/images/"
+labels_path = "./train/labels/"
+cut_ori_img_path = "./train/cut_ori_imgs/"
+cut_predict_img_path = "./train/cut_predict_imgs/"
+md_cut_ori_img_path = "../train/cut_ori_imgs/"
+md_cut_predict_img_path = "../train/cut_predict_imgs/"
+md_path = "mdfiles/"
+
+
+# 标签是经过归一化的,需要变回来
+# xywh格式 ---> box四个顶点的坐标
+def xywh2lrbt(img_w, img_h, box):
+    c, x, y, w, h = int(box[0]), float(box[1]), float(box[2]), float(box[3]), float(box[4])
+    x = x * img_w  # 中心坐标x
+    w = w * img_w  # box的宽
+    y = y * img_h  # 中心坐标y
+    h = h * img_h  # box的高
+
+    lt_x, lt_y = x - w / 2, y - h / 2  # left_top_x, left_top_y
+    lb_x, lb_y = x - w / 2, y + h / 2  # left_bottom_x, left_bottom_y
+    rb_x, rb_y = x + w / 2, y + h / 2  # right_bottom_x, right_bottom_y
+    rt_x, rt_y = x + w / 2, y - h / 2  # right_top_x, right_bottom_y
+
+    lrbt = [[lt_x, lt_y], [lb_x, lb_y], [rb_x, rb_y], [rt_x, rt_y]]
+
+    return lrbt, c
+
+
+def IOU(rect1, rect2):
+    xmin1, ymin1, xmax1, ymax1 = rect1
+    xmin2, ymin2, xmax2, ymax2 = rect2
+    s1 = (xmax1 - xmin1) * (ymax1 - ymin1)
+    s2 = (xmax2 - xmin2) * (ymax2 - ymin2)
+
+    sum_area = s1 + s2
+
+    left = max(xmin2, xmin1)
+    right = min(xmax2, xmax1)
+    top = max(ymin2, ymin1)
+    bottom = min(ymax2, ymax1)
+
+    if left >= right or top >= bottom:
+        return 0
+
+    intersection = (right - left) * (bottom - top)
+    return intersection / (sum_area - intersection) * 1.0
+
+
+def send_requests(img_path):
+    b_img_list = []
+    b_img_list.append(('file_list', ('image.jpeg', open(img_path, 'rb'), 'image/jpeg')))
+    payload = {'model_name': 'ocr-layout', 'img_size': 1920, 'download_image': False}
+    response = requests.request("POST", "http://192.168.199.249:4869/detect", headers={}, data=payload,
+                                files=b_img_list)
+    result = eval(str(response.text))
+    result = eval(str(result['result']))
+    return result
+
+
+if __name__ == "__main__":
+    list_images = os.listdir(images_path)
+    list_labels = os.listdir(labels_path)
+    all_acc = 0
+    all_label_len = 0
+
+    # 遍历每张图片
+    for m in range(len(list_images)):
+        each_acc_map = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0}  # 每张图片对应的各自预测正确数量
+        each_acc = 0  # 每张图片总共预测正确的数量
+
+        image = list_images[m]
+        img_name = image.split(".")[0]
+        # if image != "125-------------_jpg.rf.9a0d958c405ffb596b4022e13f0686b7.jpg":
+        #     continue
+        for n in range(len(list_labels)):
+            label = list_labels[n]
+            label_name = label.split(".")[0]
+            if img_name != label_name:
+                continue
+            if img_name == label_name:
+                print(image)
+                print(label)
+                print("===========")
+                # 每张图都创建一个md文件
+                mdFile = MdUtils(file_name=md_path+'yq_0819_' + str(img_name), title='yq_0819_' + str(img_name) + '_iou>=0.5')
+                list_item = ["ori_cla", "predict_cla", "cut_ori_img", "cut_predict_img", "isTrue"]
+
+                result = send_requests(images_path + image)
+
+                img = cv2.imread(images_path + image)
+                img_h, img_w, _ = img.shape
+                # 获取原图上的切割结果
+                arr_label = []  # labels的边界框,存储四个顶点坐标
+                ori_clas = []  # label的真实类别
+                with open(labels_path + label, mode="r", encoding="utf-8") as f:
+                    lines = f.readlines()
+                for line in lines:
+                    line = line.split(" ")
+                    lrbt, c = xywh2lrbt(img_w, img_h, line)
+                    arr_label.append(lrbt)
+                    ori_clas.append(c)
+
+                all_label_len += len(lines)  # 所有标签的数量,为了后面计算所有图片的准确度
+
+                num = 0
+                for i in range(len(arr_label)):
+                    lt_x, lt_y = int(arr_label[i][0][0]), int(arr_label[i][0][1])
+                    rb_x, rb_y = int(arr_label[i][2][0]), int(arr_label[i][2][1])
+                    rect1 = [lt_x, lt_y, rb_x, rb_y]
+
+                    # 切割真实标签,并保存到本地
+                    cut_ori_img = img[lt_y:rb_y, lt_x:rb_x, :]
+                    save_cut_ori_img = cut_ori_img_path + str(img_name) + "_" + str(num) + ".jpg"
+                    cv2.imwrite(save_cut_ori_img, cut_ori_img)
+                    save_cut_ori_img = mdFile.new_inline_image(text='', path=md_cut_ori_img_path+ str(img_name) + "_" + str(num) + ".jpg")
+                    ori_cla = int(ori_clas[i])
+                    print("此时的ori_cla:", d_map[ori_cla])
+                    num += 1
+
+                    for j in range(len(result[0])):
+                        res = []
+                        predict_cla = int(result[0][j]['class'])
+                        print("此时的predict_cla:", d_map[predict_cla])
+                        if ori_cla == predict_cla or (ori_cla in [2, 3, 4] and predict_cla in [2, 3, 4]):
+                            rect2 = result[0][j]['bbox']
+                            iou = IOU(rect1, rect2)
+                            if iou >= 0.5:
+                                # 切割预测结果
+                                cut_predict_img = img[rect2[1]:rect2[3], rect2[0]:rect2[2], :]
+                                save_cut_predict_img = cut_predict_img_path + str(img_name) + "_" + str(
+                                    j) + ".jpg"
+                                cv2.imwrite(save_cut_predict_img, cut_predict_img)
+                                save_cut_predict_img = mdFile.new_inline_image(text='', path=md_cut_predict_img_path+ str(img_name) + "_" + str(j) + ".jpg")
+                                isTrue = True
+                                list_item.extend(
+                                    [d_map[ori_cla], d_map[predict_cla], save_cut_ori_img, save_cut_predict_img,
+                                     isTrue])
+                                each_acc_map[ori_cla] += 1
+                                all_acc += 1
+                                each_acc += 1
+                                break
+                        if j == len(result[0]) - 1:
+                            isTrue = False
+                            list_item.extend([d_map[ori_cla], " ", save_cut_ori_img, " ", isTrue])
+                            list_item.extend(res)
+
+                # 每一张图片的总准确度
+                each_acc = round(each_acc / len(ori_clas), 3)
+                # 每一类的准确度
+                code = round(each_acc_map[0] / ori_clas.count(0), 3) if ori_clas.count(0) else "这张图没有这个标签"
+                logo = round(each_acc_map[1] / ori_clas.count(1), 3) if ori_clas.count(1) else "这张图没有这个标签"
+                style = round(each_acc_map[2] / ori_clas.count(2), 3) if ori_clas.count(2) else "这张图没有这个标签"
+                table = round(each_acc_map[3] / ori_clas.count(3), 3) if ori_clas.count(3) else "这张图没有这个标签"
+                text = round(each_acc_map[4] / ori_clas.count(4), 3) if ori_clas.count(4) else "这张图没有这个标签"
+                title = round(each_acc_map[5] / ori_clas.count(5), 3) if ori_clas.count(5) else "这张图没有这个标签"
+
+                mdFile.new_line(str(image) + "—总准确率:" + str(each_acc))
+                mdFile.new_line(str(d_map[0]) + "—准确率:" + str(code))
+                mdFile.new_line(str(d_map[1]) + "—准确率:" + str(logo))
+                mdFile.new_line(str(d_map[2]) + "—准确率:" + str(style))
+                mdFile.new_line(str(d_map[3]) + "—准确率:" + str(table))
+                mdFile.new_line(str(d_map[4]) + "—准确率:" + str(text))
+                mdFile.new_line(str(d_map[5]) + "—准确率:" + str(title))
+
+                break
+
+        mdFile.new_line()
+        mdFile.new_table(columns=5, rows=len(list_item) // 5, text=list_item, text_align='center')
+        mdFile.create_md_file()
+
+    print(all_acc, all_label_len)
+    print("所有图片的预测准确度:", all_acc / all_label_len)

+ 110 - 0
cutimgs.py

@@ -0,0 +1,110 @@
+import cv2
+import os
+import matplotlib.pyplot as plt
+
+path = "./imgs/"
+cut_img_path = "./cut_img/"
+d_map = {0: 'code', 1: 'logo', 2: 'style', 3: 'table', 4: 'text', 5: 'title'}
+
+def makedir(dir):
+    if not os.path.exists(dir):
+        os.makedirs(dir)
+
+def xywh2lrbt(img_w, img_h, box):
+    c, x, y, w, h = int(box[0]), float(box[1]), float(box[2]), float(box[3]), float(box[4])
+    print(c, x, y, w, h)
+    print("+++++++++++++++++")
+    x = x * img_w  # 中心坐标x
+    w = w * img_w  # box的宽
+    y = y * img_h  # 中心坐标y
+    h = h * img_h  # box的高
+    print(c, x, y, w, h)
+
+    lt_x, lt_y = max(0, x - w / 2), max(0, y - h / 2)  # left_top_x, left_top_y
+    lb_x, lb_y = max(0, x - w / 2), min(img_h, y + h / 2)  # left_bottom_x, left_bottom_y
+    rb_x, rb_y = min(img_w, x + w / 2), min(img_h, y + h / 2)  # right_bottom_x, right_bottom_y
+    rt_x, rt_y = min(img_w, x + w / 2), max(0, y - h / 2)  # right_top_x, right_bottom_y
+
+    lrbt = [[lt_x, lt_y], [lb_x, lb_y], [rb_x, rb_y], [rt_x, rt_y]]
+    print(lrbt)
+
+    return lrbt, c
+
+
+img_cls = os.listdir(path)  # 获取train,valid,test目录名称
+print(img_cls)
+# 遍历每个目录,拿到图片和标签
+for img_cla in img_cls:
+    # print(img_cla)
+    images_path = path + img_cla + "/images/"
+    labels_path = path + img_cla + "/labels/"
+    list_images = os.listdir(images_path)
+    list_labels = os.listdir(labels_path)
+
+    for m in range(len(list_images)):
+        image = list_images[m]
+        img_name = os.path.splitext(image)[0]
+        if image != "784_jpg.rf.8a0d3a961866d9bda06a53d9bf2fbc40.jpg":
+            continue
+        for n in range(len(list_labels)):
+            label = list_labels[n]
+            label_name = os.path.splitext(label)[0]
+
+            if img_name != label_name:
+                continue
+            else:
+                # print(image)
+                # print(label)
+                # print(img_cla)
+                # print("==================")
+
+                img = cv2.imread(images_path + image)
+                img_h, img_w, _ = img.shape
+                arr_label = []  # labels的边界框,存储四个顶点坐标
+                ori_clas = []  # label的真实类别
+                with open(labels_path + label, mode="r", encoding="utf-8") as f:
+                    lines = f.readlines()
+
+                # some labels are [],-----> image not idea
+                if not lines:
+                    break
+
+                for line in lines:
+                    line = line.split(" ")
+                    lrbt, c = xywh2lrbt(img_w, img_h, line)
+                    arr_label.append(lrbt)
+                    ori_clas.append(c)
+
+                num = 0
+                for i in range(len(arr_label)):
+                    lt_x, lt_y = int(arr_label[i][0][0]), int(arr_label[i][0][1])
+                    rb_x, rb_y = int(arr_label[i][2][0]), int(arr_label[i][2][1])
+
+                    cls_num = ori_clas[i]
+                    cls = d_map[cls_num]
+                    print(cls)
+                    makedir(cut_img_path + cls)
+
+                    # 切割真实标签,并保存到本地
+                    # print(lt_y)
+                    # print(rb_y)
+                    # print(lt_x)
+                    # print(rb_x)
+                    if lt_x >= rb_x or lt_y >= rb_y:
+                        continue
+                    cut_ori_img = img[lt_y:rb_y, lt_x:rb_x, :]
+                    # plt.imshow(cut_ori_img[:,:,::-1])
+                    # plt.show()
+                    save_cut_ori_img = cut_img_path + cls + "/" + str(img_cla) + "_" + str(m) + "_" + str(cls) + "_" + str(num) + ".jpg"
+                    # cv2.imwrite(save_cut_ori_img, cut_ori_img)
+                    num += 1
+
+                break
+
+
+
+
+
+
+
+