# create_md.py (8.3 KB)
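"""Generate per-image markdown evaluation reports.

For every image under ./train/images/ that has a label file with the same
name, the image is sent to a layout-detection service. Ground-truth boxes and
matched predicted boxes are cropped to disk, and a markdown file is written
containing a comparison table and overall/per-class accuracy figures.
"""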
import ast
import json
import os

import cv2
import numpy as np
import requests
from mdutils.mdutils import MdUtils
  9. d_map = {0: 'code', 1: 'logo', 2: 'style', 3: 'table', 4: 'text', 5: 'title'}
  10. images_path = "./train/images/"
  11. labels_path = "./train/labels/"
  12. cut_ori_img_path = "./train/cut_ori_imgs/"
  13. cut_predict_img_path = "./train/cut_predict_imgs/"
  14. md_cut_ori_img_path = "../train/cut_ori_imgs/"
  15. md_cut_predict_img_path = "../train/cut_predict_imgs/"
  16. md_path = "mdfiles/"
  17. # 标签是经过归一化的,需要变回来
  18. # xywh格式 ---> box四个顶点的坐标
  19. def xywh2lrbt(img_w, img_h, box):
  20. c, x, y, w, h = int(box[0]), float(box[1]), float(box[2]), float(box[3]), float(box[4])
  21. x = x * img_w # 中心坐标x
  22. w = w * img_w # box的宽
  23. y = y * img_h # 中心坐标y
  24. h = h * img_h # box的高
  25. lt_x, lt_y = x - w / 2, y - h / 2 # left_top_x, left_top_y
  26. lb_x, lb_y = x - w / 2, y + h / 2 # left_bottom_x, left_bottom_y
  27. rb_x, rb_y = x + w / 2, y + h / 2 # right_bottom_x, right_bottom_y
  28. rt_x, rt_y = x + w / 2, y - h / 2 # right_top_x, right_bottom_y
  29. lrbt = [[lt_x, lt_y], [lb_x, lb_y], [rb_x, rb_y], [rt_x, rt_y]]
  30. return lrbt, c
  31. def IOU(rect1, rect2):
  32. xmin1, ymin1, xmax1, ymax1 = rect1
  33. xmin2, ymin2, xmax2, ymax2 = rect2
  34. s1 = (xmax1 - xmin1) * (ymax1 - ymin1)
  35. s2 = (xmax2 - xmin2) * (ymax2 - ymin2)
  36. sum_area = s1 + s2
  37. left = max(xmin2, xmin1)
  38. right = min(xmax2, xmax1)
  39. top = max(ymin2, ymin1)
  40. bottom = min(ymax2, ymax1)
  41. if left >= right or top >= bottom:
  42. return 0
  43. intersection = (right - left) * (bottom - top)
  44. return intersection / (sum_area - intersection) * 1.0
  45. def send_requests(img_path):
  46. b_img_list = []
  47. b_img_list.append(('file_list', ('image.jpeg', open(img_path, 'rb'), 'image/jpeg')))
  48. payload = {'model_name': 'ocr-layout', 'img_size': 1920, 'download_image': False}
  49. response = requests.request("POST", "http://192.168.199.249:4869/detect", headers={}, data=payload,
  50. files=b_img_list)
  51. result = eval(str(response.text))
  52. result = eval(str(result['result']))
  53. return result
  54. if __name__ == "__main__":
  55. list_images = os.listdir(images_path)
  56. list_labels = os.listdir(labels_path)
  57. all_acc = 0
  58. all_label_len = 0
  59. # 遍历每张图片
  60. for m in range(len(list_images)):
  61. each_acc_map = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0} # 每张图片对应的各自预测正确数量
  62. each_acc = 0 # 每张图片总共预测正确的数量
  63. image = list_images[m]
  64. img_name = image.split(".")[0]
  65. # if image != "125-------------_jpg.rf.9a0d958c405ffb596b4022e13f0686b7.jpg":
  66. # continue
  67. for n in range(len(list_labels)):
  68. label = list_labels[n]
  69. label_name = label.split(".")[0]
  70. if img_name != label_name:
  71. continue
  72. if img_name == label_name:
  73. print(image)
  74. print(label)
  75. print("===========")
  76. # 每张图都创建一个md文件
  77. mdFile = MdUtils(file_name=md_path+'yq_0819_' + str(img_name), title='yq_0819_' + str(img_name) + '_iou>=0.5')
  78. list_item = ["ori_cla", "predict_cla", "cut_ori_img", "cut_predict_img", "isTrue"]
  79. result = send_requests(images_path + image)
  80. img = cv2.imread(images_path + image)
  81. img_h, img_w, _ = img.shape
  82. # 获取原图上的切割结果
  83. arr_label = [] # labels的边界框,存储四个顶点坐标
  84. ori_clas = [] # label的真实类别
  85. with open(labels_path + label, mode="r", encoding="utf-8") as f:
  86. lines = f.readlines()
  87. for line in lines:
  88. line = line.split(" ")
  89. lrbt, c = xywh2lrbt(img_w, img_h, line)
  90. arr_label.append(lrbt)
  91. ori_clas.append(c)
  92. all_label_len += len(lines) # 所有标签的数量,为了后面计算所有图片的准确度
  93. num = 0
  94. for i in range(len(arr_label)):
  95. lt_x, lt_y = int(arr_label[i][0][0]), int(arr_label[i][0][1])
  96. rb_x, rb_y = int(arr_label[i][2][0]), int(arr_label[i][2][1])
  97. rect1 = [lt_x, lt_y, rb_x, rb_y]
  98. # 切割真实标签,并保存到本地
  99. cut_ori_img = img[lt_y:rb_y, lt_x:rb_x, :]
  100. save_cut_ori_img = cut_ori_img_path + str(img_name) + "_" + str(num) + ".jpg"
  101. cv2.imwrite(save_cut_ori_img, cut_ori_img)
  102. save_cut_ori_img = mdFile.new_inline_image(text='', path=md_cut_ori_img_path+ str(img_name) + "_" + str(num) + ".jpg")
  103. ori_cla = int(ori_clas[i])
  104. print("此时的ori_cla:", d_map[ori_cla])
  105. num += 1
  106. for j in range(len(result[0])):
  107. res = []
  108. predict_cla = int(result[0][j]['class'])
  109. print("此时的predict_cla:", d_map[predict_cla])
  110. if ori_cla == predict_cla or (ori_cla in [2, 3, 4] and predict_cla in [2, 3, 4]):
  111. rect2 = result[0][j]['bbox']
  112. iou = IOU(rect1, rect2)
  113. if iou >= 0.5:
  114. # 切割预测结果
  115. cut_predict_img = img[rect2[1]:rect2[3], rect2[0]:rect2[2], :]
  116. save_cut_predict_img = cut_predict_img_path + str(img_name) + "_" + str(
  117. j) + ".jpg"
  118. cv2.imwrite(save_cut_predict_img, cut_predict_img)
  119. save_cut_predict_img = mdFile.new_inline_image(text='', path=md_cut_predict_img_path+ str(img_name) + "_" + str(j) + ".jpg")
  120. isTrue = True
  121. list_item.extend(
  122. [d_map[ori_cla], d_map[predict_cla], save_cut_ori_img, save_cut_predict_img,
  123. isTrue])
  124. each_acc_map[ori_cla] += 1
  125. all_acc += 1
  126. each_acc += 1
  127. break
  128. if j == len(result[0]) - 1:
  129. isTrue = False
  130. list_item.extend([d_map[ori_cla], " ", save_cut_ori_img, " ", isTrue])
  131. list_item.extend(res)
  132. # 每一张图片的总准确度
  133. each_acc = round(each_acc / len(ori_clas), 3)
  134. # 每一类的准确度
  135. code = round(each_acc_map[0] / ori_clas.count(0), 3) if ori_clas.count(0) else "这张图没有这个标签"
  136. logo = round(each_acc_map[1] / ori_clas.count(1), 3) if ori_clas.count(1) else "这张图没有这个标签"
  137. style = round(each_acc_map[2] / ori_clas.count(2), 3) if ori_clas.count(2) else "这张图没有这个标签"
  138. table = round(each_acc_map[3] / ori_clas.count(3), 3) if ori_clas.count(3) else "这张图没有这个标签"
  139. text = round(each_acc_map[4] / ori_clas.count(4), 3) if ori_clas.count(4) else "这张图没有这个标签"
  140. title = round(each_acc_map[5] / ori_clas.count(5), 3) if ori_clas.count(5) else "这张图没有这个标签"
  141. mdFile.new_line(str(image) + "—总准确率:" + str(each_acc))
  142. mdFile.new_line(str(d_map[0]) + "—准确率:" + str(code))
  143. mdFile.new_line(str(d_map[1]) + "—准确率:" + str(logo))
  144. mdFile.new_line(str(d_map[2]) + "—准确率:" + str(style))
  145. mdFile.new_line(str(d_map[3]) + "—准确率:" + str(table))
  146. mdFile.new_line(str(d_map[4]) + "—准确率:" + str(text))
  147. mdFile.new_line(str(d_map[5]) + "—准确率:" + str(title))
  148. break
  149. mdFile.new_line()
  150. mdFile.new_table(columns=5, rows=len(list_item) // 5, text=list_item, text_align='center')
  151. mdFile.create_md_file()
  152. print(all_acc, all_label_len)
  153. print("所有图片的预测准确度:", all_acc / all_label_len)