图像分割损失函数loss 总结+代码

汇总语义分割中常用的损失函数:

  1. cross entropy loss
  2. weighted loss
  3. focal loss
  4. dice soft loss
  5. soft iou loss
  6. Tversky Loss
  7. Generalized Dice Loss
  8. Boundary Loss
  9. Exponential Logarithmic Loss
  10. Focal Tversky Loss
  11. Sensitivity Specificity Loss
  12. Shape-aware Loss
  13. Hausdorff Distance Loss

参考论文Medical Image Segmentation Using Deep Learning:A Survey

论文地址:A survey of loss functions for semantic segmentation
代码地址https://github.com/shruti-jadon/Semantic-Segmentation-Loss-Functions
项目推荐https://github.com/JunMa11/SegLoss

图像分割一直是一个活跃的研究领域,因为它有可能修复医疗领域的漏洞,并帮助大众。在过去的5年里,各种论文提出了不同的目标损失函数,用于不同的情况下,如偏差数据,稀疏分割等。

图像分割可以定义为像素级别的分类任务。图像由各种像素组成,这些像素组合在一起定义了图像中的不同元素,因此将这些像素分类为一类元素的方法称为语义图像分割。在设计基于复杂图像分割的深度学习架构时,通常会遇到了一个至关重要的选择,即选择哪个损失/目标函数,因为它们会激发算法的学习过程。损失函数的选择对于任何架构学习正确的目标都是至关重要的,因此自2012年以来,各种研究人员开始设计针对特定领域的损失函数,以为其数据集获得更好的结果。

这些损失函数可大致分为4类:基于分布的损失函数,基于区域的损失函数,基于边界的损失函数和基于复合的损失函数( Distribution-based,Region-based,  Boundary-based,  and  Compounded)

1、cross entropy loss

用于图像语义分割任务的最常用损失函数是像素级别的交叉熵损失,这种损失会逐个检查每个像素,将对每个像素类别的预测结果(概率分布向量)与我们的独热编码标签向量进行比较。

假设我们需要对每个像素的预测类别有5个,则预测的概率分布向量长度为5:

每个像素对应的损失函数为:

整个图像的损失就是对每个像素的损失求平均值。

特别注意的是,binary entropy loss 是针对类别只有两个的情况,简称 bce loss,损失函数公式为:

#二值交叉熵,这里输入要经过sigmoid处理  
import torch  
import torch.nn as nn  
import torch.nn.functional as F  
nn.BCELoss(F.sigmoid(input), target)  
#多分类交叉熵, 用这个 loss 前面不需要加 Softmax 层  
nn.CrossEntropyLoss(input, target)

2、weighted loss

由于交叉熵损失会分别评估每个像素的类别预测,然后对所有像素的损失进行平均,因此我们实质上是在对图像中的每个像素进行平等地学习。如果多个类在图像中的分布不均衡,那么这可能导致训练过程由像素数量多的类所主导,即模型会主要学习数量多的类别样本的特征,并且学习出来的模型会更偏向将像素预测为该类别。

FCN论文和U-Net论文中针对这个问题,对输出概率分布向量中的每个值进行加权,即希望模型更加关注数量较少的样本,以缓解图像中存在的类别不均衡问题。

比如对于二分类,正负样本比例为1: 99,此时模型将所有样本都预测为负样本,那么准确率仍有99%这么高,但其实该模型没有任何使用价值。

为了平衡这个差距,就对正样本和负样本的损失赋予不同的权重,带权重的二分类损失函数公式如下:

要减少假阴性样本的数量,可以增大 pos_weight;要减少假阳性样本的数量,可以减小 pos_weight。

class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss):  
   """  
   Network has to have NO NONLINEARITY!  
   """  
   def __init__(self, weight=None):  
       super(WeightedCrossEntropyLoss, self).__init__()  
       self.weight = weight  
  
   def forward(self, inp, target):  
       target = target.long()  
       num_classes = inp.size()[1]  
  
       i0 = 1  
       i1 = 2  
  
       while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once  
           inp = inp.transpose(i0, i1)  
           i0 += 1  
           i1 += 1  
  
       inp = inp.contiguous()  
       inp = inp.view(-1, num_classes)  
  
       target = target.view(-1,)  
       wce_loss = torch.nn.CrossEntropyLoss(weight=self.weight)  
  
       return wce_loss(inp, target)

3、focal loss

上面针对不同类别的像素数量不均衡提出了改进方法,但有时还需要将像素分为难学习和容易学习这两种样本。

容易学习的样本模型可以很轻松地将其预测正确,模型只要将大量容易学习的样本分类正确,loss就可以减小很多,从而导致模型不怎么顾及难学习的样本,所以我们要想办法让模型更加关注难学习的样本。

对于较难学习的样本,将 bce loss 修改为:

其中的 γ 通常设置为2。

通过这种修改,就可以使模型更加专注于学习难学习的样本。

而将这个修改和对正负样本不均衡的修改合并在一起,就是大名鼎鼎的 focal loss:

class FocalLoss(nn.Module):  
   """  
   copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py  
   This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in  
   'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'  
       Focal_Loss= -1*alpha*(1-pt)*log(pt)  
   :param num_class:  
   :param alpha: (tensor) 3D or 4D the scalar factor for this criterion  
   :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more  
                   focus on hard misclassified example  
   :param smooth: (float,double) smooth value when cross entropy  
   :param balance_index: (int) balance class index, should be specific when alpha is float  
   :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.  
   """  
  
   def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):  
       super(FocalLoss, self).__init__()  
       self.apply_nonlin = apply_nonlin  
       self.alpha = alpha  
       self.gamma = gamma  
       self.balance_index = balance_index  
       self.smooth = smooth  
       self.size_average = size_average  
  
       if self.smooth is not None:  
           if self.smooth < 0 or self.smooth > 1.0:  
               raise ValueError('smooth value should be in [0,1]')  
  
   def forward(self, logit, target):  
       if self.apply_nonlin is not None:  
           logit = self.apply_nonlin(logit)  
       num_class = logit.shape[1]  
  
       if logit.dim() > 2:  
           # N,C,d1,d2 -> N,C,m (m=d1*d2*...)  
           logit = logit.view(logit.size(0), logit.size(1), -1)  
           logit = logit.permute(0, 2, 1).contiguous()  
           logit = logit.view(-1, logit.size(-1))  
       target = torch.squeeze(target, 1)  
       target = target.view(-1, 1)  
       # print(logit.shape, target.shape)  
       #   
       alpha = self.alpha  
  
       if alpha is None:  
           alpha = torch.ones(num_class, 1)  
       elif isinstance(alpha, (list, np.ndarray)):  
           assert len(alpha) == num_class  
           alpha = torch.FloatTensor(alpha).view(num_class, 1)  
           alpha = alpha / alpha.sum()  
       elif isinstance(alpha, float):  
           alpha = torch.ones(num_class, 1)  
           alpha = alpha * (1 - self.alpha)  
           alpha[self.balance_index] = self.alpha  
  
       else:  
           raise TypeError('Not support alpha type')  
         
       if alpha.device != logit.device:  
           alpha = alpha.to(logit.device)  
  
       idx = target.cpu().long()  
  
       one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()  
       one_hot_key = one_hot_key.scatter_(1, idx, 1)  
       if one_hot_key.device != logit.device:  
           one_hot_key = one_hot_key.to(logit.device)  
  
       if self.smooth:  
           one_hot_key = torch.clamp(  
               one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)  
       pt = (one_hot_key * logit).sum(1) + self.smooth  
       logpt = pt.log()  
  
       gamma = self.gamma  
  
       alpha = alpha[idx]  
       alpha = torch.squeeze(alpha)  
       loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt  
  
       if self.size_average:  
           loss = loss.mean()  
       else:  
           loss = loss.sum()  
       return loss

4、dice soft loss

语义分割任务中常用的还有一个基于 Dice 系数的损失函数,该系数实质上是两个样本之间重叠的度量。此度量范围为 0~1,其中 Dice 系数为1表示完全重叠。Dice 系数最初是用于二进制数据的,可以计算为:

|A∩B| 代表集合A和B之间的公共元素,并且 |A| 代表集合A中的元素数量(对于集合B同理)。

对于在预测的分割掩码上评估 Dice 系数,我们可以将 |A∩B| 近似为预测掩码和标签掩码之间的逐元素乘法,然后对结果矩阵求和。

计算 Dice 系数的分子中有一个2,那是因为分母中对两个集合的元素个数求和,两个集合的共同元素被加了两次。 为了设计一个可以最小化的损失函数,可以简单地使用 1−Dice。 这种损失函数被称为 soft Dice loss,这是因为我们直接使用预测出的概率,而不是使用阈值将其转换成一个二进制掩码。

Dice loss是针对前景比例太小的问题提出的,dice系数源于二分类,本质上是衡量两个样本的重叠部分。

对于神经网络的输出,分子与我们的预测和标签之间的共同激活有关,而分母分别与每个掩码中的激活数量有关,这具有根据标签掩码的尺寸对损失进行归一化的效果。

对于每个类别的mask,都计算一个 Dice 损失:

将每个类的 Dice 损失求和取平均,得到最后的 Dice soft loss。

下面是代码实现:

def soft_dice_loss(y_true, y_pred, epsilon=1e-6): 
    ''' 
    Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
    Assumes the `channels_last` format.
  
    # Arguments
        y_true: b x X x Y( x Z...) x c One hot encoding of ground truth
        y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax) 
        epsilon: Used for numerical stability to avoid divide by zero errors
    
    # References
        V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation 
        https://arxiv.org/abs/1606.04797
        More details on Dice loss formulation 
        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)
        
        Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
    '''
    
    # skip the batch and class axis for calculating Dice score
    axes = tuple(range(1, len(y_pred.shape)-1)) 
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
    
    return 1 - np.mean(numerator / (denominator + epsilon)) # average over classes and batch

5、soft IoU loss

前面我们知道计算 Dice 系数的公式,其实也可以表示为:

其中 TP 为真阳性样本,FP 为假阳性样本,FN 为假阴性样本。分子和分母中的 TP 样本都加了两次。

IoU 的计算公式和这个很像,区别就是 TP 只计算一次:

和 Dice soft loss 一样,通过 IoU 计算损失也是使用预测的概率值:

其中 C 表示总的类别数。

6、Tversky Loss

论文地址为:https://arxiv.org/pdf/1706.05… 

医学影像中存在很多的数据不平衡现象,使用不平衡数据进行训练会导致严重偏向高精度但低召回率(sensitivity)的预测,这是不希望的,特别是在医学应用中,假阴性比假阳性更难容忍。本文提出了一种基于Tversky指数的广义损失函数,解决了三维全卷积深神经网络训练中数据不平衡的问题,在精度和召回率之间取得了较好的折衷。

Dice loss的正则化版本,以控制假阳性和假阴性对损失函数的贡献,TL被定义为

class TverskyLoss(nn.Module):  
   def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,  
                square=False):  
       """  
       paper: https://arxiv.org/pdf/1706.05721.pdf  
       """  
       super(TverskyLoss, self).__init__()  
  
       self.square = square  
       self.do_bg = do_bg  
       self.batch_dice = batch_dice  
       self.apply_nonlin = apply_nonlin  
       self.smooth = smooth  
       self.alpha = 0.3  
       self.beta = 0.7  
  
   def forward(self, x, y, loss_mask=None):  
       shp_x = x.shape  
  
       if self.batch_dice:  
           axes = [0] + list(range(2, len(shp_x)))  
       else:  
           axes = list(range(2, len(shp_x)))  
  
       if self.apply_nonlin is not None:  
           x = self.apply_nonlin(x)  
  
       tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)  
  
  
       tversky = (tp + self.smooth) / (tp + self.alpha*fp + self.beta*fn + self.smooth)  
  
       if not self.do_bg:  
           if self.batch_dice:  
               tversky = tversky[1:]  
           else:  
               tversky = tversky[:, 1:]  
       tversky = tversky.mean()  
  
       return -tversky  

7、Generalized Dice Loss

Dice loss虽然一定程度上解决了分类失衡的问题,但却不利于严重的分类不平衡。例如小目标存在一些像素的预测误差,这很容易导致Dice的值发生很大的变化。Sudre等人提出了Generalized Dice Loss (GDL)

GDL优于Dice损失,因为不同的区域对损失有相似的贡献,并且GDL在训练过程中更稳定和鲁棒。

8、Boundary Loss

为了解决类别不平衡的问题,Kervadec等人[95]提出了一种新的用于脑损伤分割的边界损失。该损失函数旨在最小化分割边界和标记边界之间的距离。作者在两个没有标签的不平衡数据集上进行了实验。结果表明,Dice los和Boundary los的组合优于单一组合。复合损失的定义为

其中第一部分是一个标准的Dice los,它被定义为

第二部分是Boundary los,它被定义为

9、Exponential Logarithmic Loss


在(9)中,加权Dice los实际上是得到的Dice值除以每个标签的和,对不同尺度的对象达到平衡。因此,Wong等人结合focal loss [96] 和dice loss,提出了用于脑分割的指数对数损失(EXP损失),以解决严重的类不平衡问题。通过引入指数形式,可以进一步控制损失函数的非线性,以提高分割精度。EXP损失函数的定义为

其中,两个新的参数权重分别用ωdice和ωcross表示。Ldice是指数对数骰子损失,而交叉损失是交叉熵损失

其中x是像素位置,i是标签,l是位置x处的地面真值。pi(x)是从softmax输出的概率值。
在(17)中,fk是标签k出现的频率,该参数可以减少更频繁出现的标签的影响。γDice和γcross都用于增强损失函数的非线性。

10.Focal Tversky Loss

与“Focal loss”相似,后者着重于通过降低易用/常见损失的权重来说明困难的例子。Focal Tversky Loss还尝试借助γ系数来学习诸如在ROI(感兴趣区域)较小的情况下的困难示例,如下所示:

class FocalTversky_loss(nn.Module):  
   """  
   paper: https://arxiv.org/pdf/1810.07842.pdf  
   author code: https://github.com/nabsabraham/focal-tversky-unet/blob/347d39117c24540400dfe80d106d2fb06d2b99e1/losses.py#L65  
   """  
   def __init__(self, tversky_kwargs, gamma=0.75):  
       super(FocalTversky_loss, self).__init__()  
       self.gamma = gamma  
       self.tversky = TverskyLoss(**tversky_kwargs)  
  
   def forward(self, net_output, target):  
       tversky_loss = 1 + self.tversky(net_output, target) # = 1-tversky(net_output, target)  
       focal_tversky = torch.pow(tversky_loss, self.gamma)  
       return focal_tversky  

11、Sensitivity Specificity Loss

首先敏感性就是召回率,检测出确实有病的能力:

640-8.png

特异性,检测出确实没病的能力:

640-9.png

而Sensitivity Specificity Loss为:

640-10.png
image.png
class SSLoss(nn.Module):  
   def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,  
                square=False):  
       """  
       Sensitivity-Specifity loss  
       paper: http://www.rogertam.ca/Brosch_MICCAI_2015.pdf  
       tf code: https://github.com/NifTK/NiftyNet/blob/df0f86733357fdc92bbc191c8fec0dcf49aa5499/niftynet/layer/loss_segmentation.py#L392  
       """  
       super(SSLoss, self).__init__()  
  
       self.square = square  
       self.do_bg = do_bg  
       self.batch_dice = batch_dice  
       self.apply_nonlin = apply_nonlin  
       self.smooth = smooth  
       self.r = 0.1 # weight parameter in SS paper  
  
   def forward(self, net_output, gt, loss_mask=None):  
       shp_x = net_output.shape  
       shp_y = gt.shape  
       # class_num = shp_x[1]  
         
       with torch.no_grad():  
           if len(shp_x) != len(shp_y):  
               gt = gt.view((shp_y[0], 1, *shp_y[1:]))  
  
           if all([i == j for i, j in zip(net_output.shape, gt.shape)]):  
               # if this is the case then gt is probably already a one hot encoding  
               y_onehot = gt  
           else:  
               gt = gt.long()  
               y_onehot = torch.zeros(shp_x)  
               if net_output.device.type == "cuda":  
                   y_onehot = y_onehot.cuda(net_output.device.index)  
               y_onehot.scatter_(1, gt, 1)  
  
       if self.batch_dice:  
           axes = [0] + list(range(2, len(shp_x)))  
       else:  
           axes = list(range(2, len(shp_x)))  
  
       if self.apply_nonlin is not None:  
           softmax_output = self.apply_nonlin(net_output)  
         
       # no object value  
       bg_onehot = 1 - y_onehot  
       squared_error = (y_onehot - softmax_output)**2  
       specificity_part = sum_tensor(squared_error*y_onehot, axes)/(sum_tensor(y_onehot, axes)+self.smooth)  
       sensitivity_part = sum_tensor(squared_error*bg_onehot, axes)/(sum_tensor(bg_onehot, axes)+self.smooth)  
  
       ss = self.r * specificity_part + (1-self.r) * sensitivity_part  
  
       if not self.do_bg:  
           if self.batch_dice:  
               ss = ss[1:]  
           else:  
               ss = ss[:, 1:]  
       ss = ss.mean()  
  
       return ss

12、Log-Cosh Dice Loss

Dice系数是一种用于评估分割输出的度量标准。它也已修改为损失函数,因为它可以实现分割目标的数学表示。但是由于其非凸性,它多次都无法获得最佳结果。Lovsz-softmax损失旨在通过添加使用Lovsz扩展的平滑来解决非凸损失函数的问题。同时,Log-Cosh方法已广泛用于基于回归的问题中,以平滑曲线。

640.png
640-1.png

将Cosh(x)函数和Log(x)函数合并,可以得到Log-Cosh Dice Loss:

640-2.png
def log_cosh_dice_loss(self, y_true, y_pred):  
       x = self.dice_loss(y_true, y_pred)  
       return tf.math.log((torch.exp(x) + torch.exp(-x)) / 2.0)  

13、Hausdorff Distance Loss

Hausdorff Distance Loss(HD)是分割方法用来跟踪模型性能的度量。它定义为:

640-4.png

任何分割模型的目的都是为了最大化Hausdorff距离,但是由于其非凸性,因此并未广泛用作损失函数。有研究者提出了基于Hausdorff距离的损失函数的3个变量,它们都结合了度量用例,并确保损失函数易于处理。

class HDDTBinaryLoss(nn.Module):  
   def __init__(self):  
       """  
       compute haudorff loss for binary segmentation  
       https://arxiv.org/pdf/1904.10030v1.pdf          
       """  
       super(HDDTBinaryLoss, self).__init__()  
  
  
   def forward(self, net_output, target):  
       """  
       net_output: (batch_size, 2, x,y,z)  
       target: ground truth, shape: (batch_size, 1, x,y,z)  
       """  
       net_output = softmax_helper(net_output)  
       pc = net_output[:, 1, ...].type(torch.float32)  
       gt = target[:,0, ...].type(torch.float32)  
       with torch.no_grad():  
           pc_dist = compute_edts_forhdloss(pc.cpu().numpy()>0.5)  
           gt_dist = compute_edts_forhdloss(gt.cpu().numpy()>0.5)  
       # print('pc_dist.shape: ', pc_dist.shape)  
         
       pred_error = (gt - pc)**2  
       dist = pc_dist**2 + gt_dist**2 # \alpha=2 in eq(8)  
  
       dist = torch.from_numpy(dist)  
       if dist.device != pred_error.device:  
           dist = dist.to(pred_error.device).type(torch.float32)  
  
       multipled = torch.einsum("bxyz,bxyz->bxyz", pred_error, dist)  
       hd_loss = multipled.mean()  
  
       return hd_loss

总结:

交叉熵损失把每个像素都当作一个独立样本进行预测,而 dice loss 和 iou loss 则以一种更“整体”的方式来看待最终的预测输出。

这两类损失是针对不同情况,各有优点和缺点,在实际应用中,可以同时使用这两类损失来进行互补。

Deeplab v3

论文地址: https://arxiv.org/abs/1706.05587

Deeplab v3是v2版本的进一步升级,作者们在对空洞卷积重新思考的基础上,进一步对Deeplab系列的基本框架进行了优化,去掉了v1和v2版本中一直坚持的CRF后处理模块,升级了主干网络和ASPP模块,使得网络能够更好地处理语义分割中的多尺度问题。提出Deeplab v3的论文为Rethinking Atrous Convolution for Semantic Image Segmentation,是Deeplab系列后期网络的代表模型。

DeepLab V3的改进主要包括以下几方面:

1)提出了更通用的框架,适用于任何网络

2)复制了ResNet最后的block,并级联起来

3)在ASPP中使用BN层

4)去掉了CRF

随着语义分割的发展,逐渐有两大问题亟待解决:一个是连续的池化和卷积步长导致的下采样图像信息丢失问题,这个问题已经通过空洞卷积扩大感受野得到比较好的处理;另外一个则是多尺度和利用上下文信息问题。论文中分别回顾了四种基于多尺度和上下文信息进行语义分割的方法,如图1所示,包括图像金字塔、编解码架构、深度空洞卷积网络以及空间金字塔池化,这四种方法各有优缺点,ASPP可以算是深度空洞卷积和空间金字塔池化的一种结合,Deeplab v3在v2的ASPP基础上,进一步探索了空洞卷积在多尺度和上下文信息中的作用。

Deeplab v3可作为通用框架融入到任意网络结构中,具体地,以串行方式设计空洞卷积模块,复制ResNet的最后一个卷积块,并将复制后的卷积块以串行方式进行级联,如图2所示。DeepLab V3将空洞卷积应用在级联模块。具体来说,我们取ResNet中最后一个block,在下图中为block4,并在其后面增加级联模块。

在卷积块串行级联基础上,Deeplab v3又对ASPP模块进行并行级联,v3对ASPP模块进行了升级,相较于v2版本加入了批归一化(Batch Normalization,BN),通过组织不同的空洞扩张率的卷积块,同时加入图像级特征,能够更好地捕捉多尺度上下文信息,并且也能够更容易训练,如图3所示。

1)ASPP中应用了BN层

2)随着采样率的增加,滤波器中有效的权重减少了(有效权重减少,难以捕获原距离信息,这要求合理控制采样率的设置)

3)使用模型最后的特征映射的全局平均池化(为了克服远距离下有效权重减少的问题)

总结来看,Deeplab v3更充分的利用了空洞卷积来获取大范围的图像上下文信息。具体包括:多尺度信息编码、带有逐步翻倍的空洞扩张率的级联模块以及带有图像级特征的ASPP模块。实验结果表明,该模型在 PASCAL VOC数据集上相较于v2版本有了显着进步,取得了当时SOTA精度水平。

Deeplab v3的PyTorch实现可参考:

https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/deeplabv3.py

代码实现:

class DeeplabV3(ResNet):

    def __init__(self, n_class, block, layers, pyramids, grids, output_stride=16):
        self.inplanes = 64
        super(DeeplabV3, self).__init__()
        if output_stride == 16:
            strides = [1, 2, 2, 1]
            rates = [1, 1, 1, 2]
        elif output_stride == 8:
            strides = [1, 2, 1, 1]
            rates = [1, 1, 2, 2]
        else:
            raise NotImplementedError

        # Backbone Modules
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)   # h/4, w/4

        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], rate=rates[0]) # h/4, w/4
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], rate=rates[1]) # h/8, w/8
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], rate=rates[2]) # h/16,w/16
        self.layer4 = self._make_MG_unit(block, 512, blocks=grids, stride=strides[3], rate=rates[3])# h/16,w/16

        # Deeplab Modules
        self.aspp1 = ASPP_module(2048, 256, rate=pyramids[0])  
        self.aspp2 = ASPP_module(2048, 256, rate=pyramids[1])
        self.aspp3 = ASPP_module(2048, 256, rate=pyramids[2])
        self.aspp4 = ASPP_module(2048, 256, rate=pyramids[3])

        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(2048, 256, kernel_size=1, stride=1, bias=False),
                                             nn.BatchNorm2d(256),
                                             nn.ReLU())

        # get result features from the concat
        self._conv1 = nn.Sequential(nn.Conv2d(1280, 256, kernel_size=1, stride=1, bias=False),
                                    nn.BatchNorm2d(256),
                                    nn.ReLU())

        # generate the final logits
        self._conv2 = nn.Conv2d(256, n_class, kernel_size=1, bias=False)

        self.init_weight()

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)

        # image-level features
        x5 = self.global_avg_pool(x)
        x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self._conv1(x)
        x = self._conv2(x)

        x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

其中重要的_makelayer, _make_MG_unit和ASSP模块实现如下:

def _make_layer(self, block, planes, blocks, stride=1, rate=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, rate, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks=[1, 2, 4], stride=1, rate=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, rate=blocks[0] * rate, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1, rate=blocks[i] * rate))

        return nn.Sequential(*layers)


class ASPP_module(nn.Module):
    def __init__(self, inplanes, planes, rate):
        super(ASPP_module, self).__init__()
        if rate == 1:
            kernel_size = 1
            padding = 0
        else:
            kernel_size = 3
            padding = rate
        self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=rate, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

训练策略:

Crop size:

  • 为了大采样率的空洞卷积能够有效,需要较大的图片大小;否则,大采样率的空洞卷积权值就会主要用于padding区域。
  • 在Pascal VOC 2012数据集的训练和测试中我们采用了513的裁剪尺寸。

Batch normalization:

  • 我们在ResNet之上添加的模块都包括BN层
  • 当output_stride=16时,采用batchsize=16,同时BN层的参数做参数衰减0.9997。
  • 在增强的数据集上,以初始学习率0.007训练30K后,冻结BN层参数,然后采用output_stride=8,再使用初始学习率0.001在PASCAL官方的数据集上训练30K。
  • 训练output_stride=16比output_stride=8要快很多,因为其中间的特征映射在空间上小四倍。但output_stride=16在特征映射上相对粗糙,快是因为牺牲了精度。

Upsampling logits:

  • 在先前的工作上,我们是将output_stride=8的输出与Ground Truth下采样8倍做比较。
  • 现在我们发现保持Ground Truth更重要,故我们是将最终的输出上采样8倍与完整的Ground Truth比较。

Data augmentation:

在训练阶段,随机缩放输入图像(从0.5到2.0)和随机左-右翻转

Deeplab v3+

论文地址 https://arxiv.org/abs/1706.05587

Deeplab v3+是Deeplab系列最后一个网络结构,也是基于空洞卷积和多尺度系列模型的集大成者。相较于Deeplab v3,v3+版本参考了UNet系列网络,对基于空洞卷积的Deeplab网络引入了编解码结构,一定程度上来讲,Deeplab v3+是编解码和多尺度这两大系列网络的一个大融合,在很长一段时间内代表了自然图像语义分割的SOTA水平的分割模型。提出Deeplab v3+的论文为Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation,至今仍然是最常用的一个语义分割网络模型。

对于语义分割问题,尽管各种网络模型很多,但Deeplab v3+的作者们认为迄今为止仅有两大主流设计:一个是以UNet为代表的编解码结构,另一个就是以Deeplab为代表的ASPP和多尺度设计,前者以获取图像中目标对象细致的图像边界见长,后者则更擅长捕捉图像中丰富的上下文多尺度信息。鉴于此,Deeplab v3+将编解码结构引入到v3网络中,Deeplab v3直接作为编码器,然后再加入解码器设计,这样就构成了一个SPP(空间金字塔池化)+Atrous Conv(空洞卷积)+Encoder-Decoder(编解码)包含众多丰富元素的组合结构。这三种结构如下图所示。

完整的Deeplab v3+网络结构如下图所示。可以看到,输入图像先进入有深度空洞卷积构成的编码器部分中,由并行的ASPP模块构建并且融入图像级的池化特征,然后再进行合并并通过1*1卷积降低通道数后得到编码器的输出,编码器部分没有做特别的改动,直接使用了Deeplab v3的结构作为编码器,旨在提取图像的多尺度上下文信息。解码器部分则是做了一些特别的设计,编码器输出先经过4倍双线性插值上采样,同时也从编码器部分链接浅层的图像特征到解码器部分,并经1*1卷积降维后与4倍上采样的输出进行拼接,最后经过一个3*3卷积后再经过一次4倍上采样后得到最终的分割输出结果。

另外,Deeplab v3+除了延用之前的ResNet系列作为backbone之外,也尝试了以轻量级着称的Xception网络,具体做法就是将深度可分离卷积(Depth separable convolution)引入到空洞卷积中,能够极大的减少计算量,虽然有一定的精度损失,但在追求速度性能上不失为一种非常好的选择。下面简单介绍一下深度可分离卷积。

从维度的角度看,卷积核可以看成是一个空间维(宽和高)和通道维的组合,而卷积操作则可以视为空间相关性和通道相关性的联合映射。从Inception的1*1卷积来看,卷积中的空间相关性和通道相关性是可以解耦的,将它们分开进行映射,可能会达到更好的效果。

深度可分离卷积是在1*1卷积基础上的一种创新。主要包括两个部分:深度卷积和1*1卷积。深度卷积的目的在于对输入的每一个通道都单独使用一个卷积核对其进行卷积,也就是通道分离后再组合。1*1卷积的目的则在于加强深度。下面以一个例子来看一下深度可分离卷积:假设我们用128个3*3*3的滤波器对一个7*7*3的输入进行卷积,可得到5*5*128的输出,其计算量为5*5*128*3*3*3=86400,如下图所示。

然后我们看如何使用深度可分离卷积来实现同样的结果。深度可分离卷积的第一步是深度卷积。这里的深度卷积,就是分别用3个3*3*1的滤波器对输入的3个通道分别做卷积,也就是说要做3次卷积,每次卷积都有一个5*5*1的输出,组合在一起便是5*5*3的输出。现在为了拓展深度达到128,我们需要执行深度可分离卷积的第二步:1*1卷积。现在我们用128个1*1*3的滤波器对5*5*3进行卷积,就可以得到5*5*128的输出。完整过程如下图所示。

看一下深度可分离卷积的计算量如何。第一步深度卷积的计算量:5*5*1*3*3*1*3=675。第二步11卷积的计算量:5*5*128*1*1*3=9600,合计计算量为10275次。可见,相同的卷积计算输出,深度可分离卷积要比常规卷积节省12倍的计算成本。典型的应用深度可分离卷积的网络模型包括Xception和MobileNet等。本质上而言,Xception就是应用了深度可分离卷积的Inception网络。

Deeplab v3+在PASCAL VOC和Cityscapes等公开数据集上均取得了SOTA的结果,即使在深度学习语义分割发展日新月异发展的今天,Deeplab v3+仍然不失为一个非常好的语义分割解决方案。

关于Deeplab系列各版本的技术要点总结如下表所示。

Deeplab v2

DeepLabv2:
DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs

TPAMI 2018

Deeplab v2 严格上算是Deeplab v1版本的一次不大的更新,在v1的空洞卷积和CRF基础上,重点关注了网络对于多尺度问题的适用性。多尺度问题一直是目标检测和语义分割任务的重要挑战之一,以往实现多尺度的惯常做法是对同一张图片进行不同尺寸的缩放后获取对应的卷积特征图,然后将不同尺寸的特征图分别上采样后再融合来获取多尺度信息,但这种做法最大的缺点就是计算开销太大。Deeplab v2借鉴了空间金字塔池化(Spatial Pyramaid Pooling, SPP)的思路,提出了基于空洞卷积的空间金字塔池化(Atrous Spatial Pyramaid Pooling, ASPP),这也是Deeplab v2最大的亮点。提出Deeplab v2的论文为DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs,是Deeplab系列网络中前期结构的重要代表。

ASPP来源于R-CNN(Regional CNN)目标检测领域中SPP结构,该方法表明任意尺度的图像区域可以通过对单一尺度提取的卷积特征进行重采样而准确有效地分类。ASPP在其基础上将普通卷积改为空洞卷积,通过使用多个不同扩张率且并行的空洞卷积进行特征提取,最后在对每个分支进行融合。ASPP结构如下图所示。

除了ASPP之外,Deeplab v2还将v1中VGG-16的主干网络换成了ResNet-101,算是对编码器的一次升级,使其具备更强的特征提取能力。Deeplab v2在PASCAL VOC和Cityscapes等语义分割数据集上均取得了当时的SOTA结果。关于ASPP模块的一个简单实现参考如下代码所示,先是分别定义了ASPP的卷积和池化方法,然后在其基础上定义了ASPP模块。


### 定义ASPP卷积方法
class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation,
                      dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        super(ASPPConv, self).__init__(*modules)

  ### 定义ASPP池化方法
class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        size = x.shape[-2:]
        x = super(ASPPPooling, self).forward(x)
        return F.interpolate(x, size=size, mode='bilinear',
 align_corners=False)

### 定义ASPP模块
class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)))

        rate1, rate2, rate3 = tuple(atrous_rates)
        modules.append(ASPPConv(in_channels, out_channels, rate1))
        modules.append(ASPPConv(in_channels, out_channels, rate2))
        modules.append(ASPPConv(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),)

  # ASPP前向计算流程
    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)

下图是Deeplab v2在Cityscapes数据集上的分割效果:

Deeplab v1

DeepLabv1:
Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs

ICLR 2015

在语义分割发展早期,一些研究观点认为将CNN用于图像分割主要存在两个问题:一个是下采样导致的信息丢失问题,另一个则是CNN的空间不变性问题,这与CNN本身的特性有关,这种空间不变性有利于图像分类但却不利于图像分割中的像素定位。从多尺度和上下文信息的角度来看,这两个问题是导致FCN分割效果有限的重要原因。因而,相关研究针对上述两个问题提出了Deeplab v1网络,通过在常规卷积中引入空洞(Atrous)和对CNN分割结果补充CRF作为后处理来优化分割效果。提出Deeplab v1的论文为Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs,是Deeplab系列的开篇之作。

针对第一个问题,池化下采样操作引起信息丢失,Deeplab v1给出的解决方案算是另辟蹊径。常规卷积中,使用池化下采样的主要目的是增大每个像素的感受野,但在Deeplab v1中,作者们的想法是可以不用池化也可以增大像素的感受野,尝试在卷积操作本身上重新进行设计。在Deeplab v1,一种在常规卷积核中插入空洞的设计被提出,相较于池化下采样,空洞卷积能够在不降低图像分辨率的情况下扩大像素感受野,从而就避免了信息损失的问题

空洞卷积(Dilated/Atrous Convolution)也叫扩张卷积或者膨胀卷积,字面意思上来说就是在卷积核中插入空洞,起到扩大感受野的作用。空洞卷积的直接做法是在常规卷积核中填充0,用来扩大感受野,且进行计算时,空洞卷积中实际只有非零的元素起了作用。假设以一个变量a来衡量空洞卷积的扩张系数,则加入空洞之后的实际卷积核尺寸与原始卷积核尺寸之间的关系:

K=k+(k-1)(a-1)

其中为k原始卷积核大小,a为空洞率(Dilation Rate),K为经过扩展后实际卷积核大小。除此之外,空洞卷积的卷积方式跟常规卷积一样。当a=1时,空洞卷积就退化为常规卷积。a=1,2,4时,空洞卷积示意图如下图所示。

对于语义分割而言,空洞卷积主要有三个作用:

第一是扩大感受野,具体前面已经说的比较多了,这里不做重复。但需要明确一点,池化也可以扩大感受野,但空间分辨率降低了,相比之下,空洞卷积可以在扩大感受野的同时不丢失分辨率,且保持像素的相对空间位置不变。简单而言就是空洞卷积可以同时控制感受野和分辨率。

第二就是获取多尺度上下文信息。当多个带有不同空洞率的空洞卷积核叠加时,不同的感受野会带来多尺度信息,这对于分割任务是非常重要的。

第三就是可以降低计算量,不需要引入额外的参数,如图4-13所示,实际卷积时只有带有红点的元素真正进行计算。

针对第二个问题,Deeplab v1通过引入全连接的CRF来对CNN的粗分割结果进行优化。CRF作为一种经典的概率图模型,可用于图像像素之间的关系描述,在传统图像处理中主要用于图像平滑处理。但对于CNN分割问题来说,使用短程的CRFs可能会于事无补,因为分割问题的目标是恢复图像的局部细节信息,而不是对图像做平滑处理。所以Deeplab v1提出的解决方案叫做全连接CRF(Fully Connected CRF)。

条件随机场可以优化物体的边界,平滑带噪声的分割结果,去掉物体中间的预测的孔洞,使得分割结果更加准确。

CRF是一种经典的概率图模型,简单而言就是给定一组输入序列的条件下,求另一组输出序列的条件概率分布模型,CRF在自然语言处理领域有着广泛应用。CRF在语义分割后处理中用法的基本思路如下:对于FCN或者其他分割网络的粗粒度分割结果而言,每个像素点i具有对应的类别标签x_i和观测值y_i,以每个像素为节点,以像素与像素之间的关系作为边即可构建一个CRF模型。在这个CRF模型中,我们通过观测变量y_i来预测像素i对应的标签值x_i。CRF用于像素预测的结构如下图所示。

全连接CRF使用的能量函数为:

Deeplab v1模型流程如下图所示。输入图像经过深度卷积网络(DCNN)后生成浓缩的、粗粒度的语义特征图,再经过双线性插值上采样后形成粗分割结果,最后经全连接的CRF后处理生成最终的分割结果图。

下图是Deeplab v1在无CRF和有CRF后处理的分割效果图:

总的来看,Deeplab v1有如下几个有点:

(1)速度快。基于空洞卷积的CNN分割网络,能够保证8秒每帧(Frame Per Second,FPS)的推理速度,后处理的全连接CRF也仅需要0.5秒。

(2)精度高。Deeplab v1在当时取得了在PASCAL VOC数据集上SOTA的分割模型表现,准确率超过此前SOTA的7.2%。

(3)简易性。DCNN和CRF均为成熟的算法模块,将其进行简单的级联即可在当前的语义分割模型上取得好的效果。

Deeplab v1 PyTorch参考实现代码如下:

https://github.com/wangleihitcs/DeepLab-V1-PyTorch

PSPNet:Pyramid Attention Networkfor Semantic Segmentation

原文地址:https://arxiv.org/pdf/1612.01105.pdf
论文代码:https://github.com/hszhao/PSPNet

解读:
本文的主要创新点在于提出了空间金字塔池化模块,简单来说就是在encoder之后得到feature map X,使用不同尺寸的kernel进行池化(avepooling)操作,之后对得到的feature map进行上采样,使得尺寸和X大小一样,然后进行级联,进而经过卷积操作得到map。和经典的FCN的区别是在encoder和decoder之间加入了PSP模块。

场景解析(scene parsing)是语义分割的一个重要应用方向,区别于一般的语义分割任务,场景解析需要在复杂的自然图像场景下对更庞大的物体类别的每一个像素进行分类,场景解析在自动驾驶和机器人感知等方向应用广泛。但由于自然场景的复杂性、语义标签的多样性以及目标物体的多变性,对于场景解析问题的研究一直存在诸多困难。

场景解析一般基于FCN和空洞卷积网络来进行结构设计,后续的改进方案主要有两个方向,一种是多尺度特征集成,学界普遍的观点认为深层特征能够提取图像的语义信息但缺乏定位信息,将多尺度的特征集成起来能够显著提升模型效果。另一个则是结合图模型的结构预测,比如说使用CRF进行后处理来优化图像分割结果,这些改进方案虽然在一定程度上都能缓解像素定位问题,但对于更为复杂的图像场景仍然效果有限。将全局信息加入到语义分割网络中的设计也有很多研究提到,比如ParseNet,通过给FCN补充全局平均池化来提升分割效果。但这种全局信息不足以在一些复杂的场景解析数据集上体现效果,比如包括150个语义类别的ADE20K数据集。总结来看,FCN对于复杂的场景解析存在如下三个问题:

(1)像素上下文关系不匹配。在语义分割中,图像上下文信息非常重要,不同的像素类别之间往往存在着共现的视觉模式,比如说飞机通常只会出现在跑道上或者空中,而不会出现在马路上,所以飞机和跑道以及天空这三个对象之间存在着共现模式。

(2)语义标签的混淆。ADE20K数据集中有大量语义相近的类别,比如高山与丘陵、建筑与摩天大楼等。FCN在这种容易混淆的类别上很难做出准确的判断。

(3)小目标类别预测困难。在场景解析中,目标物体可以是任意大小,所以有时候对于小目标物体预测效果就很差,比如说对路灯和广告牌的识别,FCN表现就非常不好。

针对上述问题,相关学者基于空间金字塔池化(Spatial Pyramid Pooling, SPP)提出了一种更为广泛的、基于不同图像区域的全局上下文信息集成结构:PSPNet,提出PSPNet的论文为Pyramid Scene Parsing Network,是多尺度上下文结构方向的一个重要的网络设计。PSPNet网络结构如图1所示。

可以看到,将输入图像(a)经过CNN网络提取之后的特征图(b)送入到一个金字塔池化模块(Pyramid Pooling Module,PPM),在该模块中(c),使用四种不同尺度的全局池化(1*1、2*2、3*3和6*6)进行下采样并结合1*1卷积来生成降维后的上下文信息表征。然后直接使用双线性插值对低分辨率的特征图进行上采样,并与阶段(b)的特征图进行连接,形成最终的全局金字塔池化特征,最后再进行一组卷积后形成最终的语义分割预测结果(d)。基于上述结构,PSPNet对于场景解析任务能够提供有效的全局上下文先验信息,PPM相较于直接的全局池化,能够收集多尺度的信息表征,并且在计算开销上相比于FCN也增加不多。此外PSPNet训练时为防止梯度消失还添加了辅助损失函数作为一种深监督机制。

 PSPNet在ADE20K、PASCAL VOC 2012以及Cityscapes数据集上分别进行了测试,均能达到当时SOTA的分割水平。另外也在辅助损失函数和预训练模型等方面做了充分的消融实验(ablation study)。PSPNet在ADE20K数据集上的一组测试效果如图2所示。在三张测试图像上,FCN不能够捕捉图像上下文关系(将第一张图像中水面上的船只识别为汽车),但PSPNet均能够得到很好的分割效果。

网络结构

  • encoder。使用了预训练的ResNet,里面使用了孔洞卷积(后面几层没有下采样,全部使用空洞卷积)。最后输出的feature map是原图的1/8。类似下图

在encoder之后,使用了金字塔池化模块,使用了四种尺寸的金字塔,池化所用的kerne分别1×1, 2×2, 3×3 and 6×6。池化之后上采样,然后将得到的feature map,包括池化之前的做一个级联(concatenate),后面接一个卷积层得到最终的预测图像。作者提到,和全局池化相比,金字塔池化能提取多尺度的信息。

下述代码给出了PSPNet网络结构的一个简单实现流程。


### 定义PPM模块类
class _PyramidPoolingModule(nn.Module):
    def __init__(self, in_dim, reduction_dim, setting):
        super(_PyramidPoolingModule, self).__init__()
        self.features = []
        for s in setting:
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(s),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim, momentum=.95),
                nn.ReLU(inplace=True)
            ))
        self.features = nn.ModuleList(self.features)

  # PPM前向计算流程
    def forward(self, x):
        x_size = x.size()
        out = [x]
        for f in self.features:
            out.append(F.upsample(f(x), x_size[2:], mode='bilinear'))
        out = torch.cat(out, 1)
        return out

  ### 定义PSPNet类
class PSPNet(nn.Module):
    def __init__(self, num_classes, use_aux=True):
        super(PSPNet, self).__init__()
        self.use_aux = use_aux
    # 使用ResNet101作为预训练模型
        resnet = models.resnet101()
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1,1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1,1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    
        self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
        self.final = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1)
        )
    # 深监督辅助损失
        if use_aux:
            self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1)
            initialize_weights(self.aux_logits)
        initialize_weights(self.ppm, self.final)

  ### PSPNet前向计算流程
    def forward(self, x):
        x_size = x.size()
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.training and self.use_aux:
            aux = self.aux_logits(x)
        x = self.layer4(x)
        x = self.ppm(x)
        x = self.final(x)
        if self.training and self.use_aux:
            return F.upsample(x, x_size[2:], mode='bilinear'), F.upsample(aux, x_size[2:], mode='bilinear')
        return F.upsample(x, x_size[2:], mode='bilinear')

Unet++

代码:https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py

论文:https://arxiv.org/abs/1807.10165

论文作者的知乎讲解:https://zhuanlan.zhihu.com/p/44958351

UNet的编解码结构一经提出以来,大有统一深度学习图像分割之势,后续基于UNet的改进方案也经久不衰,一些研究者也在从网络结构本身来思考UNet的有效性。比如说编解码网络应该取几层,跳跃连接是否能够有更多的变化以及什么样的结构训练起来更加有效等问题。UNet本身是针对医学图像分割任务而提出来的网络结构,该任务不像自然图像分割,对分割精度要求并不是十分严格。但对于医学图像而言,器官和病灶的分割则要求极高的精确性,因为很多时候分割效果的好坏直接关系到对应的临床诊断决策。出于上述两个方面的动机,即设计更好的UNet结构和提升医学图像分割的精度,相关研究者提出了一种嵌套的UNet结构(Nested UNet),也叫UNet++,提出UNet++的论文为UNet++: A Nested U-Net Architecture for Medical Image Segmentation,发表于2018年的医学图像计算和计算机辅助干预(Medical Image Computing and Computer Assisted Intervention,MICCAI)会议上。

UNet++取名为嵌套的UNet,就在于其整体编解码网络结构中还嵌套了编解码的子网络(sub-networks),在此基础上重新设计UNet中间的跳跃连接,并补充了深监督机制加速网络训练收敛。完整的UNet++结构如下图所示。

图中黑色部分为原始的UNet结构,包括编码器下采样、解码器上采样和黑色虚线的跳跃连接三个部分;绿色部分即嵌套的UNet子网络,包括卷积和上采样两部分,而蓝色虚线部分就是UNet++重新设计后的跳跃连接,这部分跟DenseNet的密集连接类似,这里是为子网络提供跳跃连接;最上面红黑连线则是UNet++补充的深监督机制,目的是为了网络能够顺利得到训练。

下面我们从结构设计的角度来对UNet++进行解读。关于UNet结构,最首要的问题就是网络应该有几层,原始的UNet结构用了4层下采样和4层上采样,那么是不是4层就足以满足所有的分割任务需要?答案是否定的。通过本节之前的网络结构分析,我们已经知道,浅层网络能够提取图像粗粒度特征,获取图像基本形态;深层网络能够提取图像的抽象特征,获取图像语义信息,总之浅有浅的侧重,深有深的好处。同之前RefineNet的观点一样,UNet++的作者认为,不管是浅层、深层还是中层,所有层次的特征对于最后的分割都是重要的。有的数据分割任务简单,图像信息单一,可能浅层网络就足以达到很好的效果,而有的数据任务复杂,图像信息丰富,可能需要更深层的网络结构才能达到不错的效果,之前的UNet结构设计很难同时照顾到这种普适性。而UNet++通过设计不同深度的嵌套UNet子网络来实现这种普适性,所以UNet的深度到这里就解决了。

第二个问题则是加入不同深度的嵌套网络后,跳跃连接部分该如何调整。在UNet中,跳跃连接由同层编码器直连到编码器上采样对应层。但加入嵌套子网络后,UNet中原先的长连接就不复存在了,取而代之的是各子网络中的短连接。UNet++的作者们认为,长连接在UNet中是有必要的,能够将图像中前后信息联系起来,对于下采样造成的信息损失有很好的补充作用。所以,UNet++又参考DenseNet的密集连接设计,给嵌套网络补充了长连接,如上图所示。

但是这样又带来了第三个问题:反向传播的时候中间部分可能会收不到由损失函数反传回来的梯度。所以见招拆招,UNet++又通过深监督的方法来强行加梯度,帮助网络正常进行训练。但深监督对于UNet++的好处绝不仅仅限于此,通过不同深监督损失函数,UNet++可以通过网络剪枝来实现可伸缩性。所以,总结来说UNet++相较于原始的UNet,有如下两个优势:

(1)通过嵌套子网络和长短连接来整合不同层次的图像特征,使得网络分割精度更高;

(2)灵活的网络结构配合深监督机制,让参数量巨大的深度网络在可接受的精度范围内能够大幅度的缩减参数量。

UNet++与UNet等网络分割效果对比如下图所示。

UNet++也进一步壮大了UNet家族网络,后续基于其的改进版本也有很多,比如Attention UNet++、UNet 3+等。下述代码给出了UNet++的一个实现参考。


class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()
        nb_filter = [32, 64, 128, 256, 512]
        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)


    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            output = self.final(x0_4)
            return output

最后,给出作者在知乎上的一段总结:

回顾一下这次分享,我问了好多问题,也提供了其中一些我个人给出的解释。

一顿听下来,热心的网友可能会懵了,就这么个分割网络,都能说这么久,要我说就放个结构图,说这个网络很牛逼,再告知一下代码在哪儿,谢谢大家就完事儿了。

其实搞学术做研究不是这样子的,UNet++肯定马上就会被更强的结构所代替,但是要设计出更强的结构,你得首先明白这个结构,甚至它的原型U-Net设计背后的心路历程。与其和大家分享一个苍白的分割网络,我更愿意分享的是这个项目背后从开始认识U-Net,到分析它的组成,到批判性的解读,再到改进思路的形成,实验设计,像刚刚分享过程中一次次尴尬的自问自答,中间那些非常饱满的心路历程。这也是我在博士的两年中学到的做研究的范式。

说句题外话,就跟玩狼人杀一样,你一上来就说自己预言家,验了谁谁谁,那没人信的呀,你得说清楚为什么要验他,警徽怎么留,把这些心路历程都盘盘清楚,身份才能做实。

我在微博上也看到有人说,也想到了非常类似的结构,实验也快做完了,看到了我这篇论文心就凉了。其实网络结构怎么样真的不重要,重要的是你怎么能把故事给讲清楚,要是讲完以后还能够引起更多的思考和讨论那就更好了。我在分享中提到了很多我们论文中的不足之处,也非常欢迎大家可以批评指正。

nnUnet

github: https://github.com/MIC-DKFZ/nnUNet

论文地址: https://arxiv.org/abs/1809.10486

相较于常规的自然图像,以UNet为代表的编解码网络在医学图像分割中应用更为广泛。常见的各类医学成像方式,包括计算机断层扫描(Computed Tomography, CT)、核磁共振成像(Magnetic Resonance Imaging, MRI)、超声成像(Ultrasound Imaging)、X光成像(X-ray Imaging)和光学相干断层扫描(Optical Coherence Tomography, OCT)等。对于临床而言,影像学的检查是一项非常重要的诊断方式。在各类模态的影像检查中,精准地对各种器官和病灶进行分割是影像分析的关键步骤,目前深度学习图像分割在各类影像检测和分割中大放异彩。比如基于胸部CT的肺结节检测、基于颅内MR影像的脑胶质瘤分割、基于心脏CT的左心室分割、基于甲状腺超声的结节检测和基于X光的胸片肺部器官分割等。

虽然基于UNet的系列编解码分割网络在各类医学图像分割上取得了长足的进展,并且部分基于相关模型的应用设计已经广泛用于临床分析中。但医学影像本身的复杂性和差异性也极度影响着分割模型的泛化性和通用性,主要体现在以下几个方面:

(1)各类模态的医学影像之间差异大,如研究队列的大小、图像尺寸和维度、分辨率和体素(voxel)强度等。

(2)分割的语义标签的极度不平衡。相较于影像中的正常组织,病变区域一般都只占极少部分,这就造成了正常组织的体素标签与病灶组织的体素标签之间极度的类不平常。

(3)不同影像数据之间的专家标注差异大,并且一些图像的标注结果会存在模棱两可的情况。

(4)一些数据集在图像几何和形状等属性上差异明显,切片不对齐和各向异性的问题也非常严重。

提出一种鲁棒的基于2D UNet和3D UNet的自适应框架nnUMet。作者在各种任务上拿这个框架和目前的STOA方法进行了比较,且该方法不需要手动调参。最终nnUNet得到了最高的平均dice。

当前的医学图像分割被CNN的方法主导,但是在不同的任务上需要不同的结构和不同的调参策略才达到了各自任务的最佳,这些在某个任务上拿到第一的方法,在其他任务上却不行。

The Medical Segmentation Decathlon计划通过这种方式解决这个问题:希望参赛者设计一种算法,在10种数据集上进行测试,都能够达到很好的效果,而算法不能够针对某种数据集进行人为的调整,只能自动的去适应。

比赛分为两个阶段:(1)开发阶段参与者拿到7个数据集用于优化算法;(2)冻结代码后公开剩余的3个数据集,用于评估。

作者认为过多的人为调整网络结构,会导致对于特定数据集的过拟合。非网络结构方面的影响可能对于分割任务影响更大。

作者提出一种nnUNet(no-new-Net)框架,基于原始的UNet(很小的修改),不去采用哪些新的结构,如相残差连接、dense连接、注意力机制等花里胡哨的东西。相反的,把重心放在:预处理(resampling和normalization)、训练(loss,optimizer设置、数据增广)、推理(patch-based策略、test-time-augmentations集成和模型集成等)、后处理(如增强单连通域等)。

  • 网络结构

基础版UNet:2D UNet,3D UNet,UNet级联(第一级对下采样低分辨率图像进行粗分割,第二级结合第一级的结果进行微调,两级都用3DUNet)

微小修改:

(1)ReLU换 leaky ReLU(neg.slope 1e-2);

(2)Batch Norm换Instance Norm

图1 UNet Cascade. Stage1:下采样数据上进行粗分割;Stage2:分割结果和原图concat送入第二个网络进行refine

网络拓扑自适应:输入图像尺寸会有不同,而硬件的资源是有限的,因此需要在网络容量和Batch-size上做到权衡。

默认参数设置:

2D UNet:crop-size<=256×256(中值尺寸小于256时,采用中值尺寸); batch-size<=42; base-channel=30; pooling to size>=8; pooling_num<6
3D UNet: crop-size<=128x128x128(中值尺寸小于128时,采用中值尺寸); batch-size>=2; base_channel=30; pooling to size>=8; poolingnum<6

  • 预处理

整体数据Crop:只在非零区域内crop,减少计算消耗

Resample:数据集中存在不同spacing的数据,默认自动归一化到数据集所有数据spacing的中值spacing。原始数据使用三阶spline插值;Mask使用最邻近插值。

UNet Cascade采用特殊的Resample策略:中值尺寸大于显存限制下可处理尺寸的4倍时(batch-size=2),采用级联策略,对数据进行下采样(采样2的倍数,直到满足前面的要求);如果数据分辨率三个轴方向不相等,先降采样高分辨率轴使得三轴相等,再三轴同时降采样直到满足上述要求。

Normalization

CT:通过统计整个数据集中mask内像素的HU值范围,clip出[0.05,99.5]百分比范围的HU值范围,然后使用z-score方法进行归一化;

MR:对每个患者数据单独执行z-score归一化。

如果crop导致数据集的平均尺寸减小到1/4甚至更小,则只在mask内执行标准化,mask设置为0.

  • 训练过程

从头训练,使用五折交叉验证,loss函数:结合dice loss和交叉熵loss:

对于在全训练集上训练的3D-UNet(UNet Cascade的第一阶段和非级联的3D UNet,不包括UNet Cascade的第二阶段),对每一个样本单独计算dice loss,然后在batch上去平均。对其他的网络(2D UNet和UNet Cascade的第二阶段),将一个batch内的所有样本当做一个整体的样本计算整个batch上的dice(防止当crop后出现局部区域内不存在某一类时单独计算该类loss导致分母为零的情况,这也要保证batch-size不能太小)。

dice loss形式如下:

其中u为概率输出(softmax output),v为硬编码(one hot encoding)的ground truth。K为多分类类别数。

其他训练参数

Adam优化器,学习率3e-4;250个batch/epoch;
学习率调整策略:计算训练集和验证集的指数移动平均loss,如果训练集的指数移动平均loss在30个epoch内减少不够5e-3,则学习率衰减5倍;
训练停止条件:当验证集指数移动平均loss在60个epoch内减少不够5e-3,或者学习率小于1e-6,则停止训练。

数据增广:随机旋转、随机缩放、随机弹性变换、伽马校正、镜像。

注意:

1.如果3D UNet的输入patch的尺寸的最大边长是最短边长的两倍以上,那么应用三维数据扩充可能是次优的。这种情况下可以使用2D的数据增广。
2.UNet Cascade的stage 2接收前一阶段的输出作为输入的一部分,为了防止强co-adaptation,我们可以应用随机形态学操作(erode、dilate、open、close),随机的去除掉一些分割结果的连通域。

patch采样:为了增加网络的稳定性,patch采样的时候会保证一个batch的样本中有超过1/3的像素是前景类的像素。这个很关键,否则你的前景dice会收敛的很慢。

  • 推理(Inference)

所有的推理都是基于patch的。

patch的边界上精度会有损,因此在对patch重叠处的像素进行fuse时,边界的像素权重低,中心的像素权重高;patch重叠的stride为size/2;使用test-data-augmentation(增广方式:绕各个轴的镜像增广);使用了5个训练的模型集成进行推理(5个模型是通过5折交叉验证产生的5个模型)

  • 后处理

主要就是使用连通域分析

Swin-Unet:Unet形状的纯Transformer的医学图像分割

首个基于纯Transformer的U-Net形的医学图像分割网络,其中利用Swin Transformer构建编码器、bottleneck和解码器,表现SOTA!性能优于TransUnet、Att-UNet等

单位:慕尼黑工业大学, 复旦大学, 华为(田奇等人)
代码:https://github.com/HuCaoFighting/Swin-Unet
论文下载链接:https://arxiv.org/abs/2105.0553

在过去的几年中,卷积神经网络(CNN)在医学图像分析中取得了里程碑式的进展。尤其是,基于U形架构和跳跃连接的深度神经网络已广泛应用于各种医学图像任务中。但是,尽管CNN取得了出色的性能,但是由于卷积操作的局限性,它无法很好地学习全局和远程语义信息交互。

在本文中,我们提出了Swin-Unet,它是用于医学图像分割的类似Unet的纯Transformer。标记化的图像块通过跳跃连接被馈送到基于Transformer的U形En-Decoder架构中,以进行局部全局语义特征学习。

具体来说,我们使用带有偏移窗口的分层Swin Transformer作为编码器来提取上下文特征。

实验结果

在对输入和输出进行4倍的直接下采样和上采样的情况下,对多器官和心脏分割任务进行的实验表明,基于纯Transformer的U形编码器/解码器网络优于那些全卷积或者Transformer和卷积的组合。

个人感觉,这个transformer架构下的参数量应该会大很多吧(近百M)?目前作者好像并未在论文中给出对比。另外就是这个 感觉没啥创新点..感觉没啥有点水。

UNet网络

论文: https://arxiv.org/abs/1505.04597

FCN虽然做出了开创性的工作,FCN-8s相较于此前的SOTA分割表现,已经取得了巨大的优势。但从分割效果上看还很粗糙,对图像的细节处理还很不成熟,也没有考虑到像素与像素之间的上下文(context)关系,所以FCN更像是一项抛砖引玉式的工作,随着U形的编解码结构成为通用的语义分割网络设计范式,各种网络如雨后春笋般涌现。UNet是U形网络结构最经典和最主要的代表网络,因其网络结构是一个U形而得名,这类编解码的结构也因而被称之为U形结构。提出UNet的论文为U-Net: Convolutional Networks for Biomedical Image Segmentation,与FCN提出时间相差了两个月,其结构设计在FCN基础上做了进一步的改进,设计初衷主要是用于医学图像的分割。截至到本书写稿,UNet在谷歌学术上的引用次数已达44772次,堪称深度学习语义分割领域的里程碑式的工作。

1、与FCN区别

U-Net和FCN非常的相似,U-Net比FCN稍晚提出来,但都发表在2015年,和FCN相比,U-Net的第一个特点是完全对称,也就是左边和右边是很类似的,而FCN的decoder相对简单,只用了一个deconvolution的操作,之后并没有跟上卷积结构。第二个区别就是skip connection,FCN用的是加操作(summation),U-Net用的是叠操作(concatenation)。这些都是细节,重点是它们的结构用了一个比较经典的思路,也就是编码和解码(encoder-decoder),早在2006年就被Hinton大神提出来发表在了nature上.

当时这个结构提出的主要作用并不是分割,而是压缩图像和去噪声。输入是一幅图,经过下采样的编码,得到一串比原先图像更小的特征,相当于压缩,然后再经过一个解码,理想状况就是能还原到原来的图像。这样的话我们存一幅图的时候就只需要存一个特征和一个解码器即可。这个想法我个人认为是很漂亮了。同理,这个思路也可以用在原图像去噪,做法就是在训练的阶段在原图人为的加上噪声,然后放到这个编码解码器中,目标是可以还原得到原图。

后来把这个思路被用在了图像分割的问题上,也就是现在我们看到的U-Net结构,在它被提出的三年中,有很多很多的论文去讲如何改进U-Net或者FCN,不过这个分割网络的本质的拓扑结构是没有改动的。举例来说,去年ICCV上凯明大神提出的Mask RCNN. 相当于一个检测,分类,分割的集大成者,我们仔细去看它的分割部分,其实使用的也就是这个简单的FCN结构。说明了这种“U形”的编码解码结构确实非常的简洁,并且最关键的一点是好用。

2、为什么有效

相比于FCN和Deeplab等,UNet共进行了4次上采样,并在同一个stage使用了skip connection,而不是直接在高级语义特征上进行监督和loss反传,这样就保证了最后恢复出来的特征图融合了更多的low-level的feature,也使得不同scale的feature得到了的融合,从而可以进行多尺度预测和DeepSupervision。4次上采样也使得分割图恢复边缘等信息更加精细。

其次我们聊聊【医疗影像】,医疗影像有什么样的特点呢(尤其是相对于自然影像而言)?

1.图像语义较为简单、结构较为固定。我们做脑的,就用脑CT和脑MRI,做胸片的只用胸片CT,做眼底的只用眼底OCT,都是一个固定的器官的成像,而不是全身的。由于器官本身结构固定和语义信息没有特别丰富,所以高级语义信息和低级特征都显得很重要(UNet的skip connection和U型结构就派上了用场)。

2.数据量少。医学影像的数据获取相对难一些,很多比赛只提供不到100例数据。所以我们设计的模型不宜多大,参数过多,很容易导致过拟合。

原始UNet的参数量在28M左右(上采样带转置卷积的UNet参数量在31M左右),而如果把channel数成倍缩小,模型可以更小。缩小两倍后,UNet参数量在7.75M。缩小四倍,可以把模型参数量缩小至2M以内,非常轻量。个人尝试过使用Deeplab v3+和DRN等自然图像语义分割的SOTA网络在自己的项目上,发现效果和UNet差不多,但是参数量会大很多。

为什么适用于医学图像?

(1)因为医学图像边界模糊、梯度复杂,需要较多的高分辨率信息。高分辨率用于精准分割。

(2)人体内部结构相对固定,分割目标在人体图像中的分布很具有规律,语义简单明确,低分辨率信息能够提供这一信息,用于目标物体的识别。

UNet结合了低分辨率信息(提供物体类别识别依据)和高分辨率信息(提供精准分割定位依据),完美适用于医学图像分割。

网络结构

在医学图像领域,具体到更加细分的医学图像识别任务时,大量的带有高质量标注的图像数据十分难得,在此之前的通常做法是采用滑动窗口卷积(类似于图像分块)的方式来进行图像局部预测,这么做的好处是可以做图像像素做到一定程度定位,其次就是滑窗分块能够使得训练样本量增多。但缺点也很明显,一个是滑窗操作非常耗时,推理的时候效率低下,其次就是不能兼顾定位精度和像素上下文信息的利用率。UNet在FCN的基础上,完整地给出了U形的编解码结构,如下图所示

输入是一幅图,输出是目标的分割结果。继续简化就是,一幅图,编码,或者说降采样,然后解码,也就是升采样,然后输出一个分割结果。根据结果和真实分割的差异,反向传播来训练这个分割网络。我们可以说,U-Net里面最精彩的部分就是这三部分:

  • 下采样
  • 上采样
  • skip connection

UNet结构包括编码器下采样、解码器上采样和同层跳跃连接三个组成部分。编码器由4组卷积、ReLU激活和最大池化构成,每一组均有两次3*3的卷积,每个卷积层后面都有一次ReLU激活函数,然后再进行一次步长为2的2*2最大池化进行下采样,如第一组操作输入图像大小为572*572,两轮3*3的卷积之后的特征图大小为568*568,再经过22最大池化后的输出尺寸为284*284。解码器由4组2*2转置卷积、3*3卷积构成和一个ReLU激活函数构成,在最后的输出层又补充了一个1*1卷积。最后是同层跳跃连接,这也是UNet的特色操作之一,指的是将下采样时每一层的输出裁剪后连接到同层的上采样层做融合。每一次下采样都会有一个跳跃连接与对应的上采样进行融合,这种不同尺度的特征融合对上采样恢复像素大有帮助,具体来说就是高层(浅层)下采样倍数小,特征图具备更加细致的图特征,低层(深层)下采样倍数大,信息经过大量浓缩,空间损失大,但有助于目标区域(分类)判断,当高层和低层的特征进行融合时,分割效果往往会非常好。从某种程度上讲,这种跳跃连接也可以视为一种深度监督。

我们将UNet结构按照编码器、解码器和同层跳跃连接进行简化,如下图所示。编码器下采样用于特征提取和语义信息浓缩,解码器上采样用于图像像素恢复,跳跃连接则用于信息补充。自此,基于U形结构的编解码设计成为深度学习语义分割中的奠基性的网络结构,经过近几年的发展,语义分割虽然取得了长足的进步,但UNet和编解码结构一直是新的模型设计的参照对象。

代码实现:

# 导入PyTorch相关模块
import torch
import torch.nn as nn
import torch.nn.functional as F

### 编码块
class UNetEnc(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=False):
        super().__init__()
    # 每一个编码块中的结构
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, dilation=2),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers += [nn.Dropout(.5)]
        layers += [nn.MaxPool2d(2, stride=2, ceil_mode=True)]
        self.down = nn.Sequential(*layers)
  # 编码块前向计算流程
    def forward(self, x):
        return self.down(x)

### 解码块    
class UNetDec(nn.Module):
    def __init__(self, in_channels, features, out_channels):
        super().__init__()
    # 每一个解码块中的结构
        self.up = nn.Sequential(
            nn.Conv2d(in_channels, features, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, 3),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(features, out_channels, 2, stride=2),
            nn.ReLU(inplace=True),
        )
  # 解码块前向计算流程
    def forward(self, x):
        return self.up(x)

### 基于编解码的U-Net
class UNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
    # 四个编码块
        self.enc1 = UNetEnc(3, 64)
        self.enc2 = UNetEnc(64, 128)
        self.enc3 = UNetEnc(128, 256)
        self.enc4 = UNetEnc(256, 512, dropout=True)
    # 中间部分(U形底部)
        self.center = nn.Sequential(
            nn.Conv2d(512, 1024, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.ConvTranspose2d(1024, 512, 2, stride=2),
            nn.ReLU(inplace=True),
        )
    # 四个解码块
        self.dec4 = UNetDec(1024, 512, 256)
        self.dec3 = UNetDec(512, 256, 128)
        self.dec2 = UNetDec(256, 128, 64)
        self.dec1 = nn.Sequential(
            nn.Conv2d(128, 64, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3),
            nn.ReLU(inplace=True),
        )
        self.final = nn.Conv2d(64, num_classes, 1)

    # 前向传播过程
    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        center = self.center(enc4)
        # 包含了同层分辨率级联的解码块
        dec4 = self.dec4(torch.cat([
            center, F.upsample_bilinear(enc4, center.size()[2:])], 1))
        dec3 = self.dec3(torch.cat([
            dec4, F.upsample_bilinear(enc3, dec4.size()[2:])], 1))
        dec2 = self.dec2(torch.cat([
            dec3, F.upsample_bilinear(enc2, dec3.size()[2:])], 1))
        dec1 = self.dec1(torch.cat([
            dec2, F.upsample_bilinear(enc1, dec2.size()[2:])], 1))
        return F.upsample_bilinear(self.final(dec1), x.size()[2:])