一、lovasz softmax魔改
    原版 Lovasz-Softmax 要求标签为类别索引,形状为 (n, h, w);如果同时借助外部库计算 IoU 等 metric(这些指标使用 one-hot 编码的标签),两种标签格式不一致,用起来很不方便。
    新实现如下:

    1. """
    2. Lovasz-Softmax and Jaccard hinge loss in PyTorch
    3. Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
    4. """
    5. from __future__ import print_function, division
    6. import torch
    7. from torch.autograd import Variable
    8. import torch.nn.functional as F
    9. import numpy as np
    10. try:
    11. from itertools import ifilterfalse
    12. except ImportError: # py3k
    13. from itertools import filterfalse as ifilterfalse
    14. from torch.nn.modules.loss import _Loss
    def lovasz_grad(gt_sorted):
        """
        Compute the gradient of the Lovasz extension w.r.t. sorted errors.

        See Alg. 1 in the paper.

        gt_sorted: [P] Tensor of 0/1 ground-truth values, ordered like the
            sorted errors. Float, integer and bool dtypes are accepted.
        Returns a [P] float Tensor whose element i is the increase of the
            Jaccard loss (1 - intersection / union) between prefix i - 1
            and prefix i of the sorted items.
        """
        # Convert once up front: `1 - tensor` is not defined for bool tensors.
        gt_sorted = gt_sorted.float()
        p = len(gt_sorted)
        gts = gt_sorted.sum()
        intersection = gts - gt_sorted.cumsum(0)
        union = gts + (1 - gt_sorted).cumsum(0)
        jaccard = 1. - intersection / union
        if p > 1:  # cover 1-pixel case
            # Turn the cumulative losses into per-step increments.
            jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
        return jaccard
    def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
        """
        Foreground IoU, in percent, for binary masks.

        Masks use 1 for foreground and 0 for background.

        preds, labels: batches of masks; iterated pairwise when per_image is
            true, otherwise treated as one single pair.
        EMPTY: score used for a pair whose union is empty.
        ignore: label value whose pixels do not count as predicted foreground.
        Returns 100 times the mean of the per-pair IoU values.
        """
        if not per_image:
            preds, labels = (preds,), (labels,)
        scores = []
        for pred, label in zip(preds, labels):
            pred_fg = (pred == 1)
            label_fg = (label == 1)
            overlap = (label_fg & pred_fg).sum()
            union = (label_fg | (pred_fg & (label != ignore))).sum()
            scores.append(float(overlap) / float(union) if union else EMPTY)
        # mean across images if per_image
        return 100 * mean(scores)
    def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
        """
        Array of IoU values, in percent, for each (non ignored) class.

        preds, labels: batches of label maps; iterated pairwise when per_image
            is true, otherwise treated as one single pair.
        C: number of classes; class ids 0 .. C - 1 are evaluated.
        EMPTY: score used for a class whose union is empty.
        ignore: class id that is skipped and whose pixels never count as
            predicted foreground.
        Returns a numpy array with one entry per evaluated class.
        """
        if not per_image:
            preds, labels = (preds,), (labels,)
        # The ignored label is sometimes among predicted classes (ENet - CityScapes)
        evaluated = [c for c in range(C) if c != ignore]
        per_pair = []
        for pred, label in zip(preds, labels):
            row = []
            for c in evaluated:
                overlap = ((label == c) & (pred == c)).sum()
                union = ((label == c) | ((pred == c) & (label != ignore))).sum()
                row.append(float(overlap) / float(union) if union else EMPTY)
            per_pair.append(row)
        # mean across images if per_image: regroup scores class by class
        class_means = [mean(class_scores) for class_scores in zip(*per_pair)]
        return 100 * np.array(class_means)
    66. # --------------------------- BINARY LOSSES ---------------------------
    def lovasz_hinge(logits, labels, per_image=True, ignore=None):
        """
        Binary Lovasz hinge loss.

        logits: [B, H, W] Tensor, logits at each pixel (between -inf and +inf)
        labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
        per_image: compute the loss per image instead of per batch
        ignore: void class id
        Returns the loss of the whole batch, or the mean of the per-image
            losses when per_image is true.
        """
        if not per_image:
            return lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
        # Each image is re-wrapped as a batch of one before flattening.
        return mean(
            lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
            for log, lab in zip(logits, labels)
        )
    def symmetric_lovasz_hinge(outputs, targets):
        """
        Average of the Lovasz hinge loss and the same loss on the flipped problem.

        The flipped problem negates the logits and swaps the 0/1 targets.
        """
        direct = lovasz_hinge(outputs, targets)
        flipped = lovasz_hinge(-outputs, 1 - targets)
        return (direct + flipped) / 2
    def lovasz_hinge_flat(logits, labels):
        """
        Binary Lovasz hinge loss on flattened predictions.

        logits: [P] Tensor, logits at each prediction (between -inf and +inf)
        labels: [P] Tensor, binary ground truth labels (0 or 1)
        Returns a scalar Tensor.
        """
        if len(labels) == 0:
            # only void pixels, the gradients should be 0
            return logits.sum() * 0.
        # Map labels {0, 1} to signs {-1, +1} to form the hinge errors.
        signs = 2. * labels.float() - 1.
        errors = 1. - logits * signs
        errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
        gt_sorted = labels[perm]
        grad = lovasz_grad(gt_sorted)
        return torch.dot(F.relu(errors_sorted), grad)
    def flatten_binary_scores(scores, labels, ignore=None):
        """
        Flatten predictions in the batch (binary case).

        scores, labels: Tensors with the same number of elements.
        ignore: label value to drop; None keeps every element.
        Returns (scores, labels) as 1-D Tensors, without the elements whose
            label equals `ignore`.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well
        scores = scores.reshape(-1)
        labels = labels.reshape(-1)
        if ignore is None:
            return scores, labels
        valid = (labels != ignore)
        return scores[valid], labels[valid]
    class StableBCELoss(torch.nn.modules.Module):
        """
        Binary cross-entropy computed directly from logits.

        The exponential is only ever evaluated on non-positive values
        (-|x|), so it cannot overflow.
        """

        def __init__(self):
            super(StableBCELoss, self).__init__()

        def forward(self, input, target):
            """Return the mean of max(x, 0) - x * t + log(1 + exp(-|x|)) over all elements."""
            positive_part = input.clamp(min=0)
            log_term = torch.log(1 + torch.exp(-input.abs()))
            per_element = positive_part - input * target + log_term
            return per_element.mean()
    def binary_xloss(logits, labels, ignore=None):
        """
        Binary cross entropy loss.

        logits: [B, H, W] Tensor, logits at each pixel (between -inf and +inf)
        labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
        ignore: void class id
        Returns the mean loss over the pixels that are not ignored.
        """
        logits, labels = flatten_binary_scores(logits, labels, ignore)
        return StableBCELoss()(logits, labels.float())
    132. # --------------------------- MULTICLASS LOSSES ---------------------------
    def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
        """
        Multi-class Lovasz-Softmax loss.

        probas: [B, C, H, W] Tensor, class probabilities at each prediction
            (between 0 and 1). A [B, H, W] input is interpreted as binary
            (sigmoid) output.
        labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
        classes: 'all' for all, 'present' for classes present in labels, or a
            list of classes to average.
        per_image: compute the loss per image instead of per batch
        ignore: void class labels
        """
        if not per_image:
            flat_probas, flat_labels = flatten_probas(probas, labels, ignore)
            return lovasz_softmax_flat(flat_probas, flat_labels, classes=classes)

        def single_image_loss(prob, lab):
            # Re-wrap one image as a batch of one before flattening.
            flat = flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore)
            return lovasz_softmax_flat(*flat, classes=classes)

        return mean(single_image_loss(prob, lab) for prob, lab in zip(probas, labels))
    class lovasz_softmax_onehot(_Loss):
        """
        Lovasz-Softmax loss that takes one-hot encoded targets.

        The original lovasz_softmax expects index labels of shape (n, h, w).
        This wrapper accepts one-hot targets instead, which makes it easier
        to reuse the same targets for IoU and other metrics.

        classes, per_image: forwarded unchanged to lovasz_softmax; the
            defaults match lovasz_softmax's own defaults.
        """

        def __init__(self, classes='present', per_image=False):
            super().__init__()
            self.classes = classes
            self.per_image = per_image

        def forward(self, y_pred, y_true):
            """
            y_pred: (n, c, h, w), expected to be the output of softmax(dim=1)
            y_true: (n, c, h, w), expected to be one-hot encoded along dim 1
            """
            # Collapse the one-hot channel axis back to class indices (n, h, w).
            y_true = torch.argmax(y_true, dim=1)
            return lovasz_softmax(y_pred, y_true, classes=self.classes, per_image=self.per_image)
    def lovasz_softmax_flat(probas, labels, classes='present'):
        """
        Multi-class Lovasz-Softmax loss on flattened predictions.

        probas: [P, C] Tensor, class probabilities at each prediction (between 0 and 1)
        labels: [P] Tensor, ground truth labels (between 0 and C - 1)
        classes: 'all' for all, 'present' for classes present in labels, or a
            list of classes to average.
        Returns the mean of the per-class losses.
        Raises ValueError when probas has a single column and len(classes) > 1.
        """
        if probas.numel() == 0:
            # only void pixels, the gradients should be 0; reduce to a scalar
            # so the result has the shape of a regular loss value
            return probas.sum() * 0.
        C = probas.size(1)
        losses = []
        class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
        for c in class_to_sum:
            fg = (labels == c).float()  # foreground for class c
            # Compare by value: identity checks on str literals depend on interning.
            if classes == 'present' and fg.sum() == 0:
                continue
            if C == 1:
                # NOTE(review): len() is taken on the raw `classes` argument, so
                # the string selectors 'all' / 'present' also trip this check.
                if len(classes) > 1:
                    raise ValueError('Sigmoid output possible only with 1 class')
                class_pred = probas[:, 0]
            else:
                class_pred = probas[:, c]
            errors = (fg - class_pred).abs()
            errors_sorted, perm = torch.sort(errors, 0, descending=True)
            fg_sorted = fg[perm]
            losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
        return mean(losses)
    def flatten_probas(probas, labels, ignore=None):
        """
        Flatten predictions in the batch.

        probas: [B, C, H, W] Tensor, or [B, H, W] which is treated as the
            single-channel output of a sigmoid layer.
        labels: Tensor with B * H * W elements.
        ignore: label value to drop; None keeps every pixel.
        Returns (probas, labels) with shapes [P, C] and [P], where P is the
            number of pixels kept.
        """
        if probas.dim() == 3:
            # assumes output of a sigmoid layer: add a singleton class axis
            probas = probas.unsqueeze(1)
        C = probas.size(1)
        probas = probas.permute(0, 2, 3, 1).reshape(-1, C)  # B * H * W, C = P, C
        labels = labels.reshape(-1)
        if ignore is None:
            return probas, labels
        valid = (labels != ignore)
        # Boolean-mask indexing keeps the [P, C] rank even when exactly one
        # pixel is valid.
        return probas[valid], labels[valid]
    def xloss(logits, labels, ignore=None):
        """
        Cross entropy loss.

        logits: raw class scores, passed to F.cross_entropy
        labels: class-index targets
        ignore: label value excluded from the loss; None falls back to 255,
            the value this function always used before.
        """
        ignore_index = 255 if ignore is None else ignore
        return F.cross_entropy(logits, labels, ignore_index=ignore_index)
    212. # --------------------------- HELPER FUNCTIONS ---------------------------
    def isnan(x):
        """Return the result of `x != x`, which is true for NaN values."""
        return x != x
    def mean(l, ignore_nan=False, empty=0):
        """
        nanmean compatible with generators.

        l: iterable of values supporting `+` and division by an int.
        ignore_nan: when true, values for which isnan() is true are skipped.
        empty: value returned when there is nothing to average; the string
            'raise' makes the function raise ValueError instead.
        Returns the single value unchanged when there is exactly one,
            otherwise the sum divided by the count.
        """
        l = iter(l)
        if ignore_nan:
            l = ifilterfalse(isnan, l)
        try:
            n = 1
            acc = next(l)
        except StopIteration:
            if empty == 'raise':
                raise ValueError('Empty mean')
            return empty
        for n, v in enumerate(l, 2):
            # Rebind instead of `+=` so a mutable first element (tensor,
            # ndarray, ...) owned by the caller is never modified in place.
            acc = acc + v
        if n == 1:
            return acc
        return acc / n

    使用时只需导入 lovasz_softmax_onehot 即可。注意:pred 必须是经过 softmax(dim=1) 之后的概率,形状为 (n, c, h, w);label 为同形状的 one-hot 编码。下面示例中的 L 指导入上述代码所在模块时所取的别名。

    criterion = L.lovasz_softmax_onehot()
    loss = criterion(pred,label)