1. Dataset Processing

Each ground-truth box is encoded as (x-center offset, y-center offset, width, height), with every component normalized.
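A quick worked example of this encoding (a minimal sketch; the numbers are chosen for illustration, while the 448×448 input and 7×7 grid match the code below):

```python
# One corner-format box (x1, y1, x2, y2) = (112, 112, 336, 336) in a 448x448 image.
x1, y1, x2, y2 = 112 / 448, 112 / 448, 336 / 448, 336 / 448  # normalize -> (0.25, 0.25, 0.75, 0.75)
w, h = x2 - x1, y2 - y1                                      # (0.5, 0.5)
cx, cy = x1 + w / 2, y1 + h / 2                              # center (0.5, 0.5)

# With a 7x7 grid (step = 1/7) the center falls in cell (3, 3),
# and the stored x/y values are offsets inside that cell:
step = 1.0 / 7
bx, by = int(cx // step), int(cy // step)                    # (3, 3)
print(bx, by, (cx % step) / step, (cy % step) / step)        # 3 3 ~0.5 ~0.5
```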

```python
with open(opt.train_annotation_path) as f:
    train_lines = f.readlines()
with open(opt.val_annotation_path) as f:
    val_lines = f.readlines()
num_train = len(train_lines)
num_val = len(val_lines)

train_dataset = YoloDataset(train_lines, opt.input_shape, train=True)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=opt.batch_size,
                              num_workers=4, pin_memory=True, drop_last=True)
```
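Each line of the annotation file is assumed to follow the VOC-style convention that `get_random_data` below parses: an image path followed by space-separated boxes, each box written as `x_min,y_min,x_max,y_max,class_id`. An illustrative line (the path is hypothetical):

```
VOC2007/JPEGImages/000012.jpg 156,97,351,270,6 48,240,195,371,11
```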
```python
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset


class YoloDataset(Dataset):
    def __init__(self, annotation_lines, input_shape, train):
        super(YoloDataset, self).__init__()
        self.annotation_lines = annotation_lines
        self.input_shape = input_shape
        self.length = len(self.annotation_lines)
        self.train = train
        self.build_target = build_target()

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        index = index % self.length
        #---------------------------------------------------#
        #   data augmentation
        #---------------------------------------------------#
        image, boxes = self.get_random_data(self.annotation_lines[index], self.input_shape[0:2], random=self.train)
        image = torch.from_numpy(image.astype(np.float32)).permute(2, 0, 1) / 255.
        boxes = np.array(boxes, dtype=np.float32).reshape(-1, 5)
        if len(boxes) != 0:
            # normalize corner coordinates to [0, 1]
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / self.input_shape[1]
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / self.input_shape[0]
            # (x1, y1, x2, y2) -> (cx, cy, w, h)
            boxes[:, 2:4] = boxes[:, 2:4] - boxes[:, 0:2]
            boxes[:, 0:2] = boxes[:, 0:2] + boxes[:, 2:4] / 2
        image, targets = self.build_target(image, boxes[:, :4], boxes[:, 4])
        return image, targets

    def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
        line = annotation_line.split()
        image = Image.open(line[0])
        image = image.convert('RGB')
        iw, ih = image.size
        h, w = input_shape
        #------------------------------#
        #   read the box annotations
        #------------------------------#
        box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])
        if not random:
            scale = min(w / iw, h / ih)
            nw = int(iw * scale)
            nh = int(ih * scale)
            dx = (w - nw) // 2
            dy = (h - nh) // 2
            #---------------------------------#
            #   letterbox with gray padding
            #---------------------------------#
            image = image.resize((nw, nh), Image.BICUBIC)
            new_image = Image.new('RGB', (w, h), (128, 128, 128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image, np.float32)
            #---------------------------------#
            #   adjust the boxes accordingly
            #---------------------------------#
            if len(box) > 0:
                np.random.shuffle(box)
                box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
                box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
                box[:, 0:2][box[:, 0:2] < 0] = 0
                box[:, 2][box[:, 2] > w] = w
                box[:, 3][box[:, 3] > h] = h
                box_w = box[:, 2] - box[:, 0]
                box_h = box[:, 3] - box[:, 1]
                box = box[np.logical_and(box_w > 1, box_h > 1)]  # discard invalid boxes
                if len(box) == 0:
                    print(line[0])
            return image_data, box
        #------------------------------------------#
        #   random scaling and aspect-ratio jitter
        #------------------------------------------#
        new_ar = iw / ih * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)
        scale = self.rand(.25, 2)
        if new_ar < 1:
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)
        #------------------------------------------#
        #   paste onto a gray canvas at a random offset
        #------------------------------------------#
        dx = int(self.rand(0, w - nw))
        dy = int(self.rand(0, h - nh))
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_image.paste(image, (dx, dy))
        image = new_image
        #------------------------------------------#
        #   random horizontal flip
        #------------------------------------------#
        flip = self.rand() < .5
        if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
        image_data = np.array(image, np.uint8)
        #---------------------------------#
        #   random HSV color jitter
        #---------------------------------#
        r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
        hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        dtype = image_data.dtype
        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
        image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
        #---------------------------------#
        #   adjust the boxes accordingly
        #---------------------------------#
        if len(box) > 0:
            np.random.shuffle(box)
            box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
            box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
            if flip: box[:, [0, 2]] = w - box[:, [2, 0]]
            box[:, 0:2][box[:, 0:2] < 0] = 0
            box[:, 2][box[:, 2] > w] = w
            box[:, 3][box[:, 3] > h] = h
            box_w = box[:, 2] - box[:, 0]
            box_h = box[:, 3] - box[:, 1]
            box = box[np.logical_and(box_w > 1, box_h > 1)]
        return image_data, box

    def rand(self, a=0, b=1):
        return np.random.rand() * (b - a) + a


def build_target():
    return TargetTransoform()


class TargetTransoform(object):
    def __init__(self, target_shape=(7, 7, 30), class_nums=20, cell_nums=7):
        self.target_shape = target_shape
        self.class_nums = class_nums
        self.cell_nums = cell_nums

    def __call__(self, image, boxes, labels):
        """
        labels: class indices, e.g. [1, 2, 3, 4]
        boxes:  normalized (cx, cy, w, h) boxes, e.g. [[0.2, 0.3, 0.4, 0.8], ...]
        returns a target of shape [self.S, self.S, self.B * 5 + self.C]
        """
        labels = np.array(labels, "int8")
        np_target = np.zeros(self.target_shape, dtype=np.float32)
        np_class = np.zeros((len(boxes), self.class_nums))
        for i in range(len(labels)):
            np_class[i][labels[i]] = 1
        step = 1.0 / self.cell_nums
        for i in range(len(boxes)):
            box = boxes[i]
            label = np_class[i]
            cx, cy, w, h = box
            # which grid cell the box center falls in
            # (the 1e-5 keeps cx == 1.0 from indexing out of range)
            bx = int(cx // (step + 1e-5))
            by = int(cy // (step + 1e-5))
            # center offset relative to the cell's top-left corner
            cx = (cx % step) / step
            cy = (cy % step) / step
            box = [cx, cy, w, h]
            np_target[by][bx][:4] = box
            np_target[by][bx][4] = 1
            np_target[by][bx][5:9] = box
            np_target[by][bx][9] = 1
            np_target[by][bx][10:] = label
        return image, np_target


if __name__ == "__main__":
    train_path = "/home/users/user1/Documents/AI_Files/gitfies/pytorch-YOLO-v1/2007_train.txt"
    with open(train_path) as f:
        train_lines = f.readlines()
    train_dataset = YoloDataset(train_lines, [448, 448], train=True)
    image, target = train_dataset[3054]
```
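To sanity-check the 7×7×30 target, here is a minimal decoding sketch that inverts the cell-offset encoding back into normalized center-format boxes (it assumes the definitions above; `decode_target` is a hypothetical helper, not part of the original code):

```python
import numpy as np

def decode_target(np_target, cell_nums=7):
    """Invert TargetTransoform: recover (cx, cy, w, h, class_id) per occupied cell."""
    step = 1.0 / cell_nums
    boxes = []
    for by in range(cell_nums):
        for bx in range(cell_nums):
            cell = np_target[by][bx]
            if cell[4] > 0:  # confidence slot set -> this cell holds an object
                cx = (cell[0] + bx) * step  # cell offset -> absolute normalized center
                cy = (cell[1] + by) * step
                boxes.append((cx, cy, cell[2], cell[3], int(np.argmax(cell[10:]))))
    return boxes

# e.g. decode_target(target) after `image, target = train_dataset[3054]`
```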

2. Loss Function

Training drives the network to output (x-center offset, y-center offset, square root of width, square root of height), with every component normalized.
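For reference, this is the YOLOv1 loss from the original paper, which the code below reproduces term by term with $\lambda_{\text{coord}} = 5$ and $\lambda_{\text{noobj}} = 0.5$:

$$
\begin{aligned}
\mathcal{L} = {} & \lambda_{\text{coord}} \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} \left[ (x_i - \hat{x}_i)^2 + (y_i - \hat{y}_i)^2 \right] \\
& + \lambda_{\text{coord}} \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} \left[ \left(\sqrt{w_i} - \sqrt{\hat{w}_i}\right)^2 + \left(\sqrt{h_i} - \sqrt{\hat{h}_i}\right)^2 \right] \\
& + \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} \left(C_i - \hat{C}_i\right)^2
  + \lambda_{\text{noobj}} \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{noobj}} \left(C_i - \hat{C}_i\right)^2 \\
& + \sum_{i=0}^{S^2} \mathbb{1}_{i}^{\text{obj}} \sum_{c \in \text{classes}} \left(p_i(c) - \hat{p}_i(c)\right)^2
\end{aligned}
$$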

```python
import torch
import torch.nn.functional as F
from torch import nn


class yoloLoss(nn.Module):
    def __init__(self, num_class=20):
        super(yoloLoss, self).__init__()
        self.lambda_coord = 5
        self.lambda_noobj = 0.5
        self.S = 7
        self.B = 2
        self.C = num_class
        self.step = 1.0 / 7

    def compute_iou(self, box1, box2, index):
        box1 = torch.clone(box1)
        box2 = torch.clone(box2)
        # convert cell-relative offsets back to absolute (x_c, y_c, w, h)
        box1 = self.conver_box(box1, index)
        box2 = self.conver_box(box2, index)
        # predictions store sqrt(w) and sqrt(h), so square them back
        box1[:, 2] = torch.pow(box1[:, 2], 2)
        box1[:, 3] = torch.pow(box1[:, 3], 2)
        x1, y1, w1, h1 = box1[:, 0] - box1[:, 2] / 2, box1[:, 1] - box1[:, 3] / 2, box1[:, 2], box1[:, 3]
        x2, y2, w2, h2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 1] - box2[:, 3] / 2, box2[:, 2], box2[:, 3]
        # overlap = sum of extents minus the extent of the union span
        inter_w = (w1 + w2) - (torch.max(x1 + w1, x2 + w2) - torch.min(x1, x2))
        inter_h = (h1 + h2) - (torch.max(y1 + h1, y2 + h2) - torch.min(y1, y2))
        inter_h = torch.clamp(inter_h, 0)
        inter_w = torch.clamp(inter_w, 0)
        inter = inter_w * inter_h
        union = w1 * h1 + w2 * h2 - inter
        return inter / union

    def conver_box(self, box, index):
        i, j = index
        box[:, 0], box[:, 1] = [(box[:, 0] + i) * self.step, (box[:, 1] + j) * self.step]
        box = torch.clamp(box, 0)
        return box

    def forward(self, pred, target):
        batch_size = pred.size(0)
        target_boxes = target[:, :, :, :10].contiguous().reshape((-1, 7, 7, 2, 5))
        pred_boxes = pred[:, :, :, :10].contiguous().reshape((-1, 7, 7, 2, 5))
        target_cls = target[:, :, :, 10:]
        pred_cls = pred[:, :, :, 10:]
        # cells that contain an object
        obj_mask = (target_boxes[..., 4] > 0).byte()
        sig_mask = obj_mask[..., 1].bool()
        index = torch.where(sig_mask == True)
        # img_i is the image index within the batch; y and x are the grid-cell coordinates
        for img_i, y, x in zip(*index):
            img_i, y, x = img_i.item(), y.item(), x.item()
            pbox = pred_boxes[img_i, y, x]
            target_box = target_boxes[img_i, y, x]
            ious = self.compute_iou(pbox[:, :4], target_box[:, :4], [x, y])
            iou, max_i = ious.max(0)
            # the predictor with the higher IoU is "responsible":
            # its confidence target becomes that IoU ...
            target_boxes[img_i, y, x, max_i, 4] = iou.item()
            # ... and the other predictor's confidence target becomes 0
            target_boxes[img_i, y, x, 1 - max_i, 4] = 0
            obj_mask[img_i, y, x, 1 - max_i] = 0
        obj_mask = obj_mask.bool()
        noobj_mask = ~obj_mask
        noobj_loss = F.mse_loss(pred_boxes[noobj_mask][:, 4],
                                target_boxes[noobj_mask][:, 4],
                                reduction="sum")
        obj_loss = F.mse_loss(pred_boxes[obj_mask][:, 4],
                              target_boxes[obj_mask][:, 4],
                              reduction="sum")
        xy_loss = F.mse_loss(pred_boxes[obj_mask][:, :2],
                             target_boxes[obj_mask][:, :2],
                             reduction="sum")
        # predictions hold sqrt(w), sqrt(h); targets hold w, h, hence the sqrt here
        wh_loss = F.mse_loss(pred_boxes[obj_mask][:, 2:4],
                             torch.sqrt(target_boxes[obj_mask][:, 2:4]),
                             reduction="sum")
        class_loss = F.mse_loss(pred_cls[sig_mask],
                                target_cls[sig_mask],
                                reduction="sum")
        loss = dict(conf_loss=(obj_loss + self.lambda_noobj * noobj_loss) / batch_size,
                    reg_loss=(self.lambda_coord * xy_loss +
                              self.lambda_coord * wh_loss) / batch_size,
                    cls_loss=class_loss / batch_size)
        return loss
```
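A hypothetical smoke test for the loss (shapes only; the box values are made up) to confirm the expected (batch, 7, 7, 30) layout:

```python
import torch

criterion = yoloLoss()
pred = torch.rand(2, 7, 7, 30)  # fake network output in [0, 1)
target = torch.zeros(2, 7, 7, 30)
# one object in image 0, grid cell (y=3, x=3): offsets 0.5, w=0.2, h=0.3, class 6
target[0, 3, 3, :4] = torch.tensor([0.5, 0.5, 0.2, 0.3])
target[0, 3, 3, 5:9] = target[0, 3, 3, :4]
target[0, 3, 3, 4] = target[0, 3, 3, 9] = 1  # both confidence slots start at 1
target[0, 3, 3, 10 + 6] = 1                  # one-hot class vector
losses = criterion(pred, target)
print({k: v.item() for k, v in losses.items()})
total = sum(losses.values())                 # combined scalar loss
```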