Attention UNet

论文: https://arxiv.org/abs/1804.03999

以CNN为基础的编解码结构在图像分割上展现出了卓越的效果,尤其是医学图像的自动分割上。但一些研究认为以往的FCN和UNet等分割网络存在计算资源和模型参数的过度和重复使用,例如相似的低层次特征被级联内的所有网络重复提取。针对这类普遍性的问题,相关研究提出了给UNet添加注意力门控(Attention Gates, AGs)的方法,形成一个新的图像分割网络结构:Attention UNet。提出Attention UNet的论文为Attention U-Net: Learning Where to Look for the Pancreas,发表在2018年CVPR上。注意力机制原先是在自然语言处理领域被提出并逐渐得到广泛应用的一种新型结构,旨在模仿人的注意力机制,有针对性的聚焦数据中的突出特征,能够使得模型更加高效。

Attention UNet的网络结构如下图所示,需要注意的是,论文中给出的3D版本的卷积网络。其中编码器部分跟UNet编码器基本一致,主要的变化在于解码器部分。其结构简要描述如下:编码器部分,输入图像经过两组3*3*3的3D卷积和ReLU激活,然后再进行最大池化下采样,经过3组这样的卷积-池化块之后,网络进入到解码器部分。编码器最后一层的特征图除了直接进行上采样外,还与来自编码器的特征图进行注意力门控计算,然后再与上采样的特征图进行合并,经过三次这样的上采样块之后即可得到最终的分割输出图。相比于普通UNet的解码器,Attention UNet会将解码器中的特征与编码器连接过来的特征进行注意力门控处理,然后再与上采样进行拼接。经过注意力门控处理后得到的特征图会包含不同空间位置的重要性信息,使得模型能够重点关注某些目标区域。

我们将Attention UNet的注意力门控单独拿出来进行分析,看AGs是如何让模型能够聚焦到目标区域的。如图中上图所示,将Attention UNet网络中的一个上采样块单独拿出来,其中x_l为来自同层编码器的输出特征图,g表示由解码器部分用于上采样的特征图,这里同时也作为注意力门控的门控信号参数与x_l的注意力计算,而x^hat_l即为经过注意力门控计算后的特征图,此时x^hat_l是包含了空间位置重要性信息的特征图,再将其与下一层上采样后的特征图进行合并才得到该上采样块最终的输出。

将x_l和g_i计算得到的注意力系数再次与x_l相乘即可得到x^hat_l,这种经过与注意力系数相乘后的特征图会让图像中不相关的区域值变小,目标区域的值相对会变大,提升网络预测速度同时,也会提高图像的分割精度。论文中的各项实验结果也表明,经过注意力门控加成后后UNet,效果均要优于原始的UNet。下述代码给出了Attention UNet的一个2D参考实现,并且下采样次数由论文中的3次改为了4次。


### 定义Attention UNet类
class Att_UNet(nn.Module):
    def __init__(self,img_ch=3,output_ch=1):
        super(Att_UNet, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 =
       nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
    
  ### 定义前向传播流程
    def forward(self,x):
        # 编码器部分
        x1 = self.Conv1(x)
        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)
        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)
        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # 解码器+连接部分
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5,x=x4)
        d5 = torch.cat((x4,d5),dim=1)        
        d5 = self.Up_conv5(d5)        
        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4,x=x3)
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3,x=x2)
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_conv3(d3)
        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2,x=x1)
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_conv2(d2)
        d1 = self.Conv_1x1(d2)
        return d1
  
  ### 定义Attention门控块
class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
    # 注意力门控向量
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int,
            kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
            )
        # 同层编码器特征图向量
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int,
            kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
        )
    # ReLU激活函数
    self.relu = nn.ReLU(inplace=True)
    # 卷积+BN+sigmoid激活函数
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1,
            kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        
    ###  Attention门控的前向计算流程 
    def forward(self,g,x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1+x1)
        psi = self.psi(psi)
        return x*psi

总结来说,Attention UNet提出了在原始UNet基础添加注意力门控单元,注意力得分能够使得图像分割时聚焦到目标区域,该结构作为一个通用结构可以添加到任何任务类型的神经网络结构中,在语义分割网络中对前景目标区域的像素更具有敏感度。Attention UNet壮大了UNet家族网络,此后基于其的改进版本也层出不穷。

SegNet

论文(2015):SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

Github:https://github.com/alexgkendall/caffe-segnet

  • 把本文提出的架构和FCN、DeepLab-LargeFOV、DeconvNet做了比较,这种比较揭示了在实现良好分割性能的前提下内存使用情况与分割准确性的权衡。
  • SegNet的主要动机是场景理解的应用。因此它在设计的时候考虑了要在预测期间保证内存和计算时间上的效率。
  • 定量的评估表明,SegNet在和其他架构的比较上,时间和内存的使用都比较高效。

SegNet论文提出了max pooling的改进版,使用该pooling操作既可以进行下采样操作,也可以进行上采样操作。在下采样操作中同时输出pooling后的结果和pooling过程中的索引。在上采样操作中,利用下采样对应位置的索引,进行上采样操作,这样的优势在于记住了最亮特征像素的空间位置。(去除了unet里面的反卷积操作)

优点,

  1. 可以提高物体边界的分割效果
  2. 相比反卷积操作,减少了参数数量,减少了运算量,相比resize操作,减少了插值的运算量,而实际增加的索引参数也很少。
  3. 该pooling操作可以应用于任何基于编码-解码的分割模型。

SegNet网络结构如下图所示,是一个编解码完全对称的结构。其编码器直接用了VGG16的结构,并将全连接层全部改为卷积层,实际训练时可使用VGG16的预训练权重进行初始化;编码器将13层卷积层分为5组卷积块,每组卷积块之间用最大池化层进行下采样。作为一个对称结构,SegNet解码器也有13层卷积层,同样分为5组卷积块,每组卷积块之间用双线性插值和最大池化位置索引进行上采样,这也是SegNet最大的特色。

SegNet研究团队认为编码器下采样过程中图像信息损失较多,直接存储所有卷积块的特征图又非常占用内存,因而在SegNet中提出在每一次最大池化下采样前存储最大池化的位置索引(Max-pooling indices),即记住最大池化操作中,最大值在2*2池化窗口中的位置。每个2*2窗口仅需要2 bits内存存储量,这种池化位置索引可用于上采样解码时恢复图像信息。下图给出了SegNet与FCN之间的上采样方法对比。可以观察到,SegNet使用双线性插值并结合最大池化位置索引进行上采样,而FCN则是基于去卷积结合编码器卷积特征图进行上采样。

SegNet这种轻量化的上采样方式,不仅能够提升图像边界分割效果,在端到端的实时分割项目中速度也非常快,并且这种结构设计可以配置到任意的编解码网络中,是一种优秀的分割网络设计方式。下述代码给出了SegNet的一个简易的结构实现,因为SegNet解码器的特殊性,我们单独定义了一个解码器类,编码器部分直接使用VGG16的预训练权重层,然后在编解码器基础上搭建SegNet并定义前向计算流程。


# 导入PyTorch相关模块
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torchvision import models

# 定义SegNet解码器类
class SegNetDec(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
        ]
        layers += [
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
        ] * num_layers
        layers += [
            nn.Conv2d(in_channels // 2, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.decode = nn.Sequential(*layers)

    def forward(self, x):
        return self.decode(x)

### 定义SegNet类
class SegNet(nn.Module):
    def __init__(self, classes):
        super().__init__()
    # 编码器使用vgg16预训练权重
        vgg16 = models.vgg16(pretrained=True)
        features = vgg16.features
        self.enc1 = features[0: 4]
        self.enc2 = features[5: 9]
        self.enc3 = features[10: 16]
        self.enc4 = features[17: 23]
        self.enc5 = features[24: -1]
    # 编码器卷积层不参与训练
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.requires_grad = False
    
        self.dec5 = SegNetDec(512, 512, 1)
        self.dec4 = SegNetDec(512, 256, 1)
        self.dec3 = SegNetDec(256, 128, 1)
        self.dec2 = SegNetDec(128, 64, 0)

        self.final = nn.Sequential(*[
            nn.Conv2d(64, classes, 3, padding=1),
            nn.BatchNorm2d(classes),
            nn.ReLU(inplace=True)
        ])
  # 定义SegNet前向计算流程
    def forward(self, x):
        x1 = self.enc1(x)
        e1, m1 = F.max_pool2d(x1, kernel_size=2, stride=2,
 return_indices=True)
        x2 = self.enc2(e1)
        e2, m2 = F.max_pool2d(x2, kernel_size=2, stride=2,
 return_indices=True)
        x3 = self.enc3(e2)
        e3, m3 = F.max_pool2d(x3, kernel_size=2, stride=2,
 return_indices=True)
        x4 = self.enc4(e3)
        e4, m4 = F.max_pool2d(x4, kernel_size=2, stride=2,
 return_indices=True)
        x5 = self.enc5(e4)
        e5, m5 = F.max_pool2d(x5, kernel_size=2, stride=2,
 return_indices=True)

        def upsample(d):
            d5 = self.dec5(F.max_unpool2d(d, m5, kernel_size=2,
 stride=2, output_size=x5.size()))
            d4 = self.dec4(F.max_unpool2d(d5, m4, kernel_size=2,
 stride=2, output_size=x4.size()))
            d3 = self.dec3(F.max_unpool2d(d4, m3, kernel_size=2,
 stride=2, output_size=x3.size()))
            d2 = self.dec2(F.max_unpool2d(d3, m2, kernel_size=2,
 stride=2, output_size=x2.size()))
            d1 = F.max_unpool2d(d2, m1, kernel_size=2, stride=2,
 output_size=x1.size())
            return d1

        d = upsample(e5)
        return self.final(d)

图像分割损失函数loss 总结+代码

汇总语义分割中常用的损失函数:

  1. cross entropy loss
  2. weighted loss
  3. focal loss
  4. dice soft loss
  5. soft iou loss
  6. Tversky Loss
  7. Generalized Dice Loss
  8. Boundary Loss
  9. Exponential Logarithmic Loss
  10. Focal Tversky Loss
  11. Sensitivity Specificity Loss
  12. Shape-aware Loss
  13. Hausdorff Distance Loss

参考论文Medical Image Segmentation Using Deep Learning:A Survey

论文地址:A survey of loss functions for semantic segmentation
代码地址https://github.com/shruti-jadon/Semantic-Segmentation-Loss-Functions
项目推荐https://github.com/JunMa11/SegLoss

图像分割一直是一个活跃的研究领域,因为它有可能修复医疗领域的漏洞,并帮助大众。在过去的5年里,各种论文提出了不同的目标损失函数,用于不同的情况下,如偏差数据,稀疏分割等。

图像分割可以定义为像素级别的分类任务。图像由各种像素组成,这些像素组合在一起定义了图像中的不同元素,因此将这些像素分类为一类元素的方法称为语义图像分割。在设计基于复杂图像分割的深度学习架构时,通常会遇到了一个至关重要的选择,即选择哪个损失/目标函数,因为它们会激发算法的学习过程。损失函数的选择对于任何架构学习正确的目标都是至关重要的,因此自2012年以来,各种研究人员开始设计针对特定领域的损失函数,以为其数据集获得更好的结果。

这些损失函数可大致分为4类:基于分布的损失函数,基于区域的损失函数,基于边界的损失函数和基于复合的损失函数( Distribution-based,Region-based,  Boundary-based,  and  Compounded)

1、cross entropy loss

用于图像语义分割任务的最常用损失函数是像素级别的交叉熵损失,这种损失会逐个检查每个像素,将对每个像素类别的预测结果(概率分布向量)与我们的独热编码标签向量进行比较。

假设我们需要对每个像素的预测类别有5个,则预测的概率分布向量长度为5:

每个像素对应的损失函数为:

整个图像的损失就是对每个像素的损失求平均值。

特别注意的是,binary entropy loss 是针对类别只有两个的情况,简称 bce loss,损失函数公式为:

#二值交叉熵,这里输入要经过sigmoid处理  
import torch  
import torch.nn as nn  
import torch.nn.functional as F  
nn.BCELoss(F.sigmoid(input), target)  
#多分类交叉熵, 用这个 loss 前面不需要加 Softmax 层  
nn.CrossEntropyLoss(input, target)

2、weighted loss

由于交叉熵损失会分别评估每个像素的类别预测,然后对所有像素的损失进行平均,因此我们实质上是在对图像中的每个像素进行平等地学习。如果多个类在图像中的分布不均衡,那么这可能导致训练过程由像素数量多的类所主导,即模型会主要学习数量多的类别样本的特征,并且学习出来的模型会更偏向将像素预测为该类别。

FCN论文和U-Net论文中针对这个问题,对输出概率分布向量中的每个值进行加权,即希望模型更加关注数量较少的样本,以缓解图像中存在的类别不均衡问题。

比如对于二分类,正负样本比例为1: 99,此时模型将所有样本都预测为负样本,那么准确率仍有99%这么高,但其实该模型没有任何使用价值。

为了平衡这个差距,就对正样本和负样本的损失赋予不同的权重,带权重的二分类损失函数公式如下:

要减少假阴性样本的数量,可以增大 pos_weight;要减少假阳性样本的数量,可以减小 pos_weight。

class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss):  
   """  
   Network has to have NO NONLINEARITY!  
   """  
   def __init__(self, weight=None):  
       super(WeightedCrossEntropyLoss, self).__init__()  
       self.weight = weight  
  
   def forward(self, inp, target):  
       target = target.long()  
       num_classes = inp.size()[1]  
  
       i0 = 1  
       i1 = 2  
  
       while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once  
           inp = inp.transpose(i0, i1)  
           i0 += 1  
           i1 += 1  
  
       inp = inp.contiguous()  
       inp = inp.view(-1, num_classes)  
  
       target = target.view(-1,)  
       wce_loss = torch.nn.CrossEntropyLoss(weight=self.weight)  
  
       return wce_loss(inp, target)

3、focal loss

上面针对不同类别的像素数量不均衡提出了改进方法,但有时还需要将像素分为难学习和容易学习这两种样本。

容易学习的样本模型可以很轻松地将其预测正确,模型只要将大量容易学习的样本分类正确,loss就可以减小很多,从而导致模型不怎么顾及难学习的样本,所以我们要想办法让模型更加关注难学习的样本。

对于较难学习的样本,将 bce loss 修改为:

其中的 γ 通常设置为2。

通过这种修改,就可以使模型更加专注于学习难学习的样本。

而将这个修改和对正负样本不均衡的修改合并在一起,就是大名鼎鼎的 focal loss:

class FocalLoss(nn.Module):  
   """  
   copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py  
   This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in  
   'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'  
       Focal_Loss= -1*alpha*(1-pt)*log(pt)  
   :param num_class:  
   :param alpha: (tensor) 3D or 4D the scalar factor for this criterion  
   :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more  
                   focus on hard misclassified example  
   :param smooth: (float,double) smooth value when cross entropy  
   :param balance_index: (int) balance class index, should be specific when alpha is float  
   :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.  
   """  
  
   def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):  
       super(FocalLoss, self).__init__()  
       self.apply_nonlin = apply_nonlin  
       self.alpha = alpha  
       self.gamma = gamma  
       self.balance_index = balance_index  
       self.smooth = smooth  
       self.size_average = size_average  
  
       if self.smooth is not None:  
           if self.smooth < 0 or self.smooth > 1.0:  
               raise ValueError('smooth value should be in [0,1]')  
  
   def forward(self, logit, target):  
       if self.apply_nonlin is not None:  
           logit = self.apply_nonlin(logit)  
       num_class = logit.shape[1]  
  
       if logit.dim() > 2:  
           # N,C,d1,d2 -> N,C,m (m=d1*d2*...)  
           logit = logit.view(logit.size(0), logit.size(1), -1)  
           logit = logit.permute(0, 2, 1).contiguous()  
           logit = logit.view(-1, logit.size(-1))  
       target = torch.squeeze(target, 1)  
       target = target.view(-1, 1)  
       # print(logit.shape, target.shape)  
       #   
       alpha = self.alpha  
  
       if alpha is None:  
           alpha = torch.ones(num_class, 1)  
       elif isinstance(alpha, (list, np.ndarray)):  
           assert len(alpha) == num_class  
           alpha = torch.FloatTensor(alpha).view(num_class, 1)  
           alpha = alpha / alpha.sum()  
       elif isinstance(alpha, float):  
           alpha = torch.ones(num_class, 1)  
           alpha = alpha * (1 - self.alpha)  
           alpha[self.balance_index] = self.alpha  
  
       else:  
           raise TypeError('Not support alpha type')  
         
       if alpha.device != logit.device:  
           alpha = alpha.to(logit.device)  
  
       idx = target.cpu().long()  
  
       one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()  
       one_hot_key = one_hot_key.scatter_(1, idx, 1)  
       if one_hot_key.device != logit.device:  
           one_hot_key = one_hot_key.to(logit.device)  
  
       if self.smooth:  
           one_hot_key = torch.clamp(  
               one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)  
       pt = (one_hot_key * logit).sum(1) + self.smooth  
       logpt = pt.log()  
  
       gamma = self.gamma  
  
       alpha = alpha[idx]  
       alpha = torch.squeeze(alpha)  
       loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt  
  
       if self.size_average:  
           loss = loss.mean()  
       else:  
           loss = loss.sum()  
       return loss

4、dice soft loss

语义分割任务中常用的还有一个基于 Dice 系数的损失函数,该系数实质上是两个样本之间重叠的度量。此度量范围为 0~1,其中 Dice 系数为1表示完全重叠。Dice 系数最初是用于二进制数据的,可以计算为:

|A∩B| 代表集合A和B之间的公共元素,并且 |A| 代表集合A中的元素数量(对于集合B同理)。

对于在预测的分割掩码上评估 Dice 系数,我们可以将 |A∩B| 近似为预测掩码和标签掩码之间的逐元素乘法,然后对结果矩阵求和。

计算 Dice 系数的分子中有一个2,那是因为分母中对两个集合的元素个数求和,两个集合的共同元素被加了两次。 为了设计一个可以最小化的损失函数,可以简单地使用 1−Dice。 这种损失函数被称为 soft Dice loss,这是因为我们直接使用预测出的概率,而不是使用阈值将其转换成一个二进制掩码。

Dice loss是针对前景比例太小的问题提出的,dice系数源于二分类,本质上是衡量两个样本的重叠部分。

对于神经网络的输出,分子与我们的预测和标签之间的共同激活有关,而分母分别与每个掩码中的激活数量有关,这具有根据标签掩码的尺寸对损失进行归一化的效果。

对于每个类别的mask,都计算一个 Dice 损失:

将每个类的 Dice 损失求和取平均,得到最后的 Dice soft loss。

下面是代码实现:

def soft_dice_loss(y_true, y_pred, epsilon=1e-6): 
    ''' 
    Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
    Assumes the `channels_last` format.
  
    # Arguments
        y_true: b x X x Y( x Z...) x c One hot encoding of ground truth
        y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax) 
        epsilon: Used for numerical stability to avoid divide by zero errors
    
    # References
        V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation 
        https://arxiv.org/abs/1606.04797
        More details on Dice loss formulation 
        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)
        
        Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
    '''
    
    # skip the batch and class axis for calculating Dice score
    axes = tuple(range(1, len(y_pred.shape)-1)) 
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
    
    return 1 - np.mean(numerator / (denominator + epsilon)) # average over classes and batch

5、soft IoU loss

前面我们知道计算 Dice 系数的公式,其实也可以表示为:

其中 TP 为真阳性样本,FP 为假阳性样本,FN 为假阴性样本。分子和分母中的 TP 样本都加了两次。

IoU 的计算公式和这个很像,区别就是 TP 只计算一次:

和 Dice soft loss 一样,通过 IoU 计算损失也是使用预测的概率值:

其中 C 表示总的类别数。

6、Tversky Loss

论文地址为:https://arxiv.org/pdf/1706.05… 

医学影像中存在很多的数据不平衡现象,使用不平衡数据进行训练会导致严重偏向高精度但低召回率(sensitivity)的预测,这是不希望的,特别是在医学应用中,假阴性比假阳性更难容忍。本文提出了一种基于Tversky指数的广义损失函数,解决了三维全卷积深神经网络训练中数据不平衡的问题,在精度和召回率之间取得了较好的折衷。

Dice loss的正则化版本,以控制假阳性和假阴性对损失函数的贡献,TL被定义为

class TverskyLoss(nn.Module):  
   def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,  
                square=False):  
       """  
       paper: https://arxiv.org/pdf/1706.05721.pdf  
       """  
       super(TverskyLoss, self).__init__()  
  
       self.square = square  
       self.do_bg = do_bg  
       self.batch_dice = batch_dice  
       self.apply_nonlin = apply_nonlin  
       self.smooth = smooth  
       self.alpha = 0.3  
       self.beta = 0.7  
  
   def forward(self, x, y, loss_mask=None):  
       shp_x = x.shape  
  
       if self.batch_dice:  
           axes = [0] + list(range(2, len(shp_x)))  
       else:  
           axes = list(range(2, len(shp_x)))  
  
       if self.apply_nonlin is not None:  
           x = self.apply_nonlin(x)  
  
       tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)  
  
  
       tversky = (tp + self.smooth) / (tp + self.alpha*fp + self.beta*fn + self.smooth)  
  
       if not self.do_bg:  
           if self.batch_dice:  
               tversky = tversky[1:]  
           else:  
               tversky = tversky[:, 1:]  
       tversky = tversky.mean()  
  
       return -tversky  

7、Generalized Dice Loss

Dice loss虽然一定程度上解决了分类失衡的问题,但却不利于严重的分类不平衡。例如小目标存在一些像素的预测误差,这很容易导致Dice的值发生很大的变化。Sudre等人提出了Generalized Dice Loss (GDL)

GDL优于Dice损失,因为不同的区域对损失有相似的贡献,并且GDL在训练过程中更稳定和鲁棒。

8、Boundary Loss

为了解决类别不平衡的问题,Kervadec等人[95]提出了一种新的用于脑损伤分割的边界损失。该损失函数旨在最小化分割边界和标记边界之间的距离。作者在两个没有标签的不平衡数据集上进行了实验。结果表明,Dice los和Boundary los的组合优于单一组合。复合损失的定义为

其中第一部分是一个标准的Dice los,它被定义为

第二部分是Boundary los,它被定义为

9、Exponential Logarithmic Loss


在(9)中,加权Dice los实际上是得到的Dice值除以每个标签的和,对不同尺度的对象达到平衡。因此,Wong等人结合focal loss [96] 和dice loss,提出了用于脑分割的指数对数损失(EXP损失),以解决严重的类不平衡问题。通过引入指数形式,可以进一步控制损失函数的非线性,以提高分割精度。EXP损失函数的定义为

其中,两个新的参数权重分别用ωdice和ωcross表示。Ldice是指数对数骰子损失,而交叉损失是交叉熵损失

其中x是像素位置,i是标签,l是位置x处的地面真值。pi(x)是从softmax输出的概率值。
在(17)中,fk是标签k出现的频率,该参数可以减少更频繁出现的标签的影响。γDice和γcross都用于增强损失函数的非线性。

10.Focal Tversky Loss

与“Focal loss”相似,后者着重于通过降低易用/常见损失的权重来说明困难的例子。Focal Tversky Loss还尝试借助γ系数来学习诸如在ROI(感兴趣区域)较小的情况下的困难示例,如下所示:

class FocalTversky_loss(nn.Module):  
   """  
   paper: https://arxiv.org/pdf/1810.07842.pdf  
   author code: https://github.com/nabsabraham/focal-tversky-unet/blob/347d39117c24540400dfe80d106d2fb06d2b99e1/losses.py#L65  
   """  
   def __init__(self, tversky_kwargs, gamma=0.75):  
       super(FocalTversky_loss, self).__init__()  
       self.gamma = gamma  
       self.tversky = TverskyLoss(**tversky_kwargs)  
  
   def forward(self, net_output, target):  
       tversky_loss = 1 + self.tversky(net_output, target) # = 1-tversky(net_output, target)  
       focal_tversky = torch.pow(tversky_loss, self.gamma)  
       return focal_tversky  

11、Sensitivity Specificity Loss

首先敏感性就是召回率,检测出确实有病的能力:

640-8.png

特异性,检测出确实没病的能力:

640-9.png

而Sensitivity Specificity Loss为:

640-10.png
image.png
class SSLoss(nn.Module):  
   def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,  
                square=False):  
       """  
       Sensitivity-Specifity loss  
       paper: http://www.rogertam.ca/Brosch_MICCAI_2015.pdf  
       tf code: https://github.com/NifTK/NiftyNet/blob/df0f86733357fdc92bbc191c8fec0dcf49aa5499/niftynet/layer/loss_segmentation.py#L392  
       """  
       super(SSLoss, self).__init__()  
  
       self.square = square  
       self.do_bg = do_bg  
       self.batch_dice = batch_dice  
       self.apply_nonlin = apply_nonlin  
       self.smooth = smooth  
       self.r = 0.1 # weight parameter in SS paper  
  
   def forward(self, net_output, gt, loss_mask=None):  
       shp_x = net_output.shape  
       shp_y = gt.shape  
       # class_num = shp_x[1]  
         
       with torch.no_grad():  
           if len(shp_x) != len(shp_y):  
               gt = gt.view((shp_y[0], 1, *shp_y[1:]))  
  
           if all([i == j for i, j in zip(net_output.shape, gt.shape)]):  
               # if this is the case then gt is probably already a one hot encoding  
               y_onehot = gt  
           else:  
               gt = gt.long()  
               y_onehot = torch.zeros(shp_x)  
               if net_output.device.type == "cuda":  
                   y_onehot = y_onehot.cuda(net_output.device.index)  
               y_onehot.scatter_(1, gt, 1)  
  
       if self.batch_dice:  
           axes = [0] + list(range(2, len(shp_x)))  
       else:  
           axes = list(range(2, len(shp_x)))  
  
       if self.apply_nonlin is not None:  
           softmax_output = self.apply_nonlin(net_output)  
         
       # no object value  
       bg_onehot = 1 - y_onehot  
       squared_error = (y_onehot - softmax_output)**2  
       specificity_part = sum_tensor(squared_error*y_onehot, axes)/(sum_tensor(y_onehot, axes)+self.smooth)  
       sensitivity_part = sum_tensor(squared_error*bg_onehot, axes)/(sum_tensor(bg_onehot, axes)+self.smooth)  
  
       ss = self.r * specificity_part + (1-self.r) * sensitivity_part  
  
       if not self.do_bg:  
           if self.batch_dice:  
               ss = ss[1:]  
           else:  
               ss = ss[:, 1:]  
       ss = ss.mean()  
  
       return ss

12、Log-Cosh Dice Loss

Dice系数是一种用于评估分割输出的度量标准。它也已修改为损失函数,因为它可以实现分割目标的数学表示。但是由于其非凸性,它多次都无法获得最佳结果。Lovsz-softmax损失旨在通过添加使用Lovsz扩展的平滑来解决非凸损失函数的问题。同时,Log-Cosh方法已广泛用于基于回归的问题中,以平滑曲线。

640.png
640-1.png

将Cosh(x)函数和Log(x)函数合并,可以得到Log-Cosh Dice Loss:

640-2.png
def log_cosh_dice_loss(self, y_true, y_pred):  
       x = self.dice_loss(y_true, y_pred)  
       return tf.math.log((torch.exp(x) + torch.exp(-x)) / 2.0)  

13、Hausdorff Distance Loss

Hausdorff Distance Loss(HD)是分割方法用来跟踪模型性能的度量。它定义为:

640-4.png

任何分割模型的目的都是为了最大化Hausdorff距离,但是由于其非凸性,因此并未广泛用作损失函数。有研究者提出了基于Hausdorff距离的损失函数的3个变量,它们都结合了度量用例,并确保损失函数易于处理。

class HDDTBinaryLoss(nn.Module):  
   def __init__(self):  
       """  
       compute haudorff loss for binary segmentation  
       https://arxiv.org/pdf/1904.10030v1.pdf          
       """  
       super(HDDTBinaryLoss, self).__init__()  
  
  
   def forward(self, net_output, target):  
       """  
       net_output: (batch_size, 2, x,y,z)  
       target: ground truth, shape: (batch_size, 1, x,y,z)  
       """  
       net_output = softmax_helper(net_output)  
       pc = net_output[:, 1, ...].type(torch.float32)  
       gt = target[:,0, ...].type(torch.float32)  
       with torch.no_grad():  
           pc_dist = compute_edts_forhdloss(pc.cpu().numpy()>0.5)  
           gt_dist = compute_edts_forhdloss(gt.cpu().numpy()>0.5)  
       # print('pc_dist.shape: ', pc_dist.shape)  
         
       pred_error = (gt - pc)**2  
       dist = pc_dist**2 + gt_dist**2 # \alpha=2 in eq(8)  
  
       dist = torch.from_numpy(dist)  
       if dist.device != pred_error.device:  
           dist = dist.to(pred_error.device).type(torch.float32)  
  
       multipled = torch.einsum("bxyz,bxyz->bxyz", pred_error, dist)  
       hd_loss = multipled.mean()  
  
       return hd_loss

总结:

交叉熵损失把每个像素都当作一个独立样本进行预测,而 dice loss 和 iou loss 则以一种更“整体”的方式来看待最终的预测输出。

这两类损失是针对不同情况,各有优点和缺点,在实际应用中,可以同时使用这两类损失来进行互补。

Deeplab v3

论文地址: https://arxiv.org/abs/1706.05587

Deeplab v3是v2版本的进一步升级,作者们在对空洞卷积重新思考的基础上,进一步对Deeplab系列的基本框架进行了优化,去掉了v1和v2版本中一直坚持的CRF后处理模块,升级了主干网络和ASPP模块,使得网络能够更好地处理语义分割中的多尺度问题。提出Deeplab v3的论文为Rethinking Atrous Convolution for Semantic Image Segmentation,是Deeplab系列后期网络的代表模型。

DeepLab V3的改进主要包括以下几方面:

1)提出了更通用的框架,适用于任何网络

2)复制了ResNet最后的block,并级联起来

3)在ASPP中使用BN层

4)去掉了CRF

随着语义分割的发展,逐渐有两大问题亟待解决:一个是连续的池化和卷积步长导致的下采样图像信息丢失问题,这个问题已经通过空洞卷积扩大感受野得到比较好的处理;另外一个则是多尺度和利用上下文信息问题。论文中分别回顾了四种基于多尺度和上下文信息进行语义分割的方法,如图1所示,包括图像金字塔、编解码架构、深度空洞卷积网络以及空间金字塔池化,这四种方法各有优缺点,ASPP可以算是深度空洞卷积和空间金字塔池化的一种结合,Deeplab v3在v2的ASPP基础上,进一步探索了空洞卷积在多尺度和上下文信息中的作用。

Deeplab v3可作为通用框架融入到任意网络结构中,具体地,以串行方式设计空洞卷积模块,复制ResNet的最后一个卷积块,并将复制后的卷积块以串行方式进行级联,如图2所示。DeepLab V3将空洞卷积应用在级联模块。具体来说,我们取ResNet中最后一个block,在下图中为block4,并在其后面增加级联模块。

在卷积块串行级联基础上,Deeplab v3又对ASPP模块进行并行级联,v3对ASPP模块进行了升级,相较于v2版本加入了批归一化(Batch Normalization,BN),通过组织不同的空洞扩张率的卷积块,同时加入图像级特征,能够更好地捕捉多尺度上下文信息,并且也能够更容易训练,如图3所示。

1)ASPP中应用了BN层

2)随着采样率的增加,滤波器中有效的权重减少了(有效权重减少,难以捕获原距离信息,这要求合理控制采样率的设置)

3)使用模型最后的特征映射的全局平均池化(为了克服远距离下有效权重减少的问题)

总结来看,Deeplab v3更充分的利用了空洞卷积来获取大范围的图像上下文信息。具体包括:多尺度信息编码、带有逐步翻倍的空洞扩张率的级联模块以及带有图像级特征的ASPP模块。实验结果表明,该模型在 PASCAL VOC数据集上相较于v2版本有了显着进步,取得了当时SOTA精度水平。

Deeplab v3的PyTorch实现可参考:

https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/deeplabv3.py

代码实现:

class DeeplabV3(ResNet):

    def __init__(self, n_class, block, layers, pyramids, grids, output_stride=16):
        self.inplanes = 64
        super(DeeplabV3, self).__init__()
        if output_stride == 16:
            strides = [1, 2, 2, 1]
            rates = [1, 1, 1, 2]
        elif output_stride == 8:
            strides = [1, 2, 1, 1]
            rates = [1, 1, 2, 2]
        else:
            raise NotImplementedError

        # Backbone Modules
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)   # h/4, w/4

        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], rate=rates[0]) # h/4, w/4
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], rate=rates[1]) # h/8, w/8
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], rate=rates[2]) # h/16,w/16
        self.layer4 = self._make_MG_unit(block, 512, blocks=grids, stride=strides[3], rate=rates[3])# h/16,w/16

        # Deeplab Modules
        self.aspp1 = ASPP_module(2048, 256, rate=pyramids[0])  
        self.aspp2 = ASPP_module(2048, 256, rate=pyramids[1])
        self.aspp3 = ASPP_module(2048, 256, rate=pyramids[2])
        self.aspp4 = ASPP_module(2048, 256, rate=pyramids[3])

        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(2048, 256, kernel_size=1, stride=1, bias=False),
                                             nn.BatchNorm2d(256),
                                             nn.ReLU())

        # get result features from the concat
        self._conv1 = nn.Sequential(nn.Conv2d(1280, 256, kernel_size=1, stride=1, bias=False),
                                    nn.BatchNorm2d(256),
                                    nn.ReLU())

        # generate the final logits
        self._conv2 = nn.Conv2d(256, n_class, kernel_size=1, bias=False)

        self.init_weight()

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)

        # image-level features
        x5 = self.global_avg_pool(x)
        x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self._conv1(x)
        x = self._conv2(x)

        x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

其中重要的_makelayer, _make_MG_unit和ASSP模块实现如下:

def _make_layer(self, block, planes, blocks, stride=1, rate=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, rate, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks=[1, 2, 4], stride=1, rate=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, rate=blocks[0] * rate, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1, rate=blocks[i] * rate))

        return nn.Sequential(*layers)


class ASPP_module(nn.Module):
    def __init__(self, inplanes, planes, rate):
        super(ASPP_module, self).__init__()
        if rate == 1:
            kernel_size = 1
            padding = 0
        else:
            kernel_size = 3
            padding = rate
        self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=rate, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

训练策略:

Crop size:

  • 为了大采样率的空洞卷积能够有效,需要较大的图片大小;否则,大采样率的空洞卷积权值就会主要用于padding区域。
  • 在Pascal VOC 2012数据集的训练和测试中我们采用了513的裁剪尺寸。

Batch normalization:

  • 我们在ResNet之上添加的模块都包括BN层
  • 当output_stride=16时,采用batchsize=16,同时BN层的参数做参数衰减0.9997。
  • 在增强的数据集上,以初始学习率0.007训练30K后,冻结BN层参数,然后采用output_stride=8,再使用初始学习率0.001在PASCAL官方的数据集上训练30K。
  • 训练output_stride=16比output_stride=8要快很多,因为其中间的特征映射在空间上小四倍。但output_stride=16在特征映射上相对粗糙,快是因为牺牲了精度。

Upsampling logits:

  • 在先前的工作上,我们是将output_stride=8的输出与Ground Truth下采样8倍做比较。
  • 现在我们发现保持Ground Truth更重要,故我们是将最终的输出上采样8倍与完整的Ground Truth比较。

Data augmentation:

在训练阶段,随机缩放输入图像(从0.5到2.0)和随机左-右翻转

Deeplab v3+

论文地址 https://arxiv.org/abs/1706.05587

Deeplab v3+是Deeplab系列最后一个网络结构,也是基于空洞卷积和多尺度系列模型的集大成者。相较于Deeplab v3,v3+版本参考了UNet系列网络,对基于空洞卷积的Deeplab网络引入了编解码结构,一定程度上来讲,Deeplab v3+是编解码和多尺度这两大系列网络的一个大融合,在很长一段时间内代表了自然图像语义分割的SOTA水平的分割模型。提出Deeplab v3+的论文为Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation,至今仍然是最常用的一个语义分割网络模型。

对于语义分割问题,尽管各种网络模型很多,但Deeplab v3+的作者们认为迄今为止仅有两大主流设计:一个是以UNet为代表的编解码结构,另一个就是以Deeplab为代表的ASPP和多尺度设计,前者以获取图像中目标对象细致的图像边界见长,后者则更擅长捕捉图像中丰富的上下文多尺度信息。鉴于此,Deeplab v3+将编解码结构引入到v3网络中,Deeplab v3直接作为编码器,然后再加入解码器设计,这样就构成了一个SPP(空间金字塔池化)+Atrous Conv(空洞卷积)+Encoder-Decoder(编解码)包含众多丰富元素的组合结构。这三种结构如下图所示。

完整的Deeplab v3+网络结构如下图所示。可以看到,输入图像先进入有深度空洞卷积构成的编码器部分中,由并行的ASPP模块构建并且融入图像级的池化特征,然后再进行合并并通过1*1卷积降低通道数后得到编码器的输出,编码器部分没有做特别的改动,直接使用了Deeplab v3的结构作为编码器,旨在提取图像的多尺度上下文信息。解码器部分则是做了一些特别的设计,编码器输出先经过4倍双线性插值上采样,同时也从编码器部分链接浅层的图像特征到解码器部分,并经1*1卷积降维后与4倍上采样的输出进行拼接,最后经过一个3*3卷积后再经过一次4倍上采样后得到最终的分割输出结果。

另外,Deeplab v3+除了延用之前的ResNet系列作为backbone之外,也尝试了以轻量级着称的Xception网络,具体做法就是将深度可分离卷积(Depth separable convolution)引入到空洞卷积中,能够极大的减少计算量,虽然有一定的精度损失,但在追求速度性能上不失为一种非常好的选择。下面简单介绍一下深度可分离卷积。

从维度的角度看,卷积核可以看成是一个空间维(宽和高)和通道维的组合,而卷积操作则可以视为空间相关性和通道相关性的联合映射。从Inception的1*1卷积来看,卷积中的空间相关性和通道相关性是可以解耦的,将它们分开进行映射,可能会达到更好的效果。

深度可分离卷积是在1*1卷积基础上的一种创新。主要包括两个部分:深度卷积和1*1卷积。深度卷积的目的在于对输入的每一个通道都单独使用一个卷积核对其进行卷积,也就是通道分离后再组合。1*1卷积的目的则在于加强深度。下面以一个例子来看一下深度可分离卷积:假设我们用128个3*3*3的滤波器对一个7*7*3的输入进行卷积,可得到5*5*128的输出,其计算量为5*5*128*3*3*3=86400,如下图所示。

然后我们看如何使用深度可分离卷积来实现同样的结果。深度可分离卷积的第一步是深度卷积。这里的深度卷积,就是分别用3个3*3*1的滤波器对输入的3个通道分别做卷积,也就是说要做3次卷积,每次卷积都有一个5*5*1的输出,组合在一起便是5*5*3的输出。现在为了拓展深度达到128,我们需要执行深度可分离卷积的第二步:1*1卷积。现在我们用128个1*1*3的滤波器对5*5*3进行卷积,就可以得到5*5*128的输出。完整过程如下图所示。

看一下深度可分离卷积的计算量如何。第一步深度卷积的计算量:5*5*1*3*3*1*3=675。第二步11卷积的计算量:5*5*128*1*1*3=9600,合计计算量为10275次。可见,相同的卷积计算输出,深度可分离卷积要比常规卷积节省12倍的计算成本。典型的应用深度可分离卷积的网络模型包括Xception和MobileNet等。本质上而言,Xception就是应用了深度可分离卷积的Inception网络。

Deeplab v3+在PASCAL VOC和Cityscapes等公开数据集上均取得了SOTA的结果,即使在深度学习语义分割发展日新月异发展的今天,Deeplab v3+仍然不失为一个非常好的语义分割解决方案。

关于Deeplab系列各版本的技术要点总结如下表所示。

Deeplab v2

DeepLabv2:
DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs

TPAMI 2018

Deeplab v2 严格上算是Deeplab v1版本的一次不大的更新,在v1的空洞卷积和CRF基础上,重点关注了网络对于多尺度问题的适用性。多尺度问题一直是目标检测和语义分割任务的重要挑战之一,以往实现多尺度的惯常做法是对同一张图片进行不同尺寸的缩放后获取对应的卷积特征图,然后将不同尺寸的特征图分别上采样后再融合来获取多尺度信息,但这种做法最大的缺点就是计算开销太大。Deeplab v2借鉴了空间金字塔池化(Spatial Pyramaid Pooling, SPP)的思路,提出了基于空洞卷积的空间金字塔池化(Atrous Spatial Pyramaid Pooling, ASPP),这也是Deeplab v2最大的亮点。提出Deeplab v2的论文为DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs,是Deeplab系列网络中前期结构的重要代表。

ASPP来源于R-CNN(Regional CNN)目标检测领域中SPP结构,该方法表明任意尺度的图像区域可以通过对单一尺度提取的卷积特征进行重采样而准确有效地分类。ASPP在其基础上将普通卷积改为空洞卷积,通过使用多个不同扩张率且并行的空洞卷积进行特征提取,最后在对每个分支进行融合。ASPP结构如下图所示。

除了ASPP之外,Deeplab v2还将v1中VGG-16的主干网络换成了ResNet-101,算是对编码器的一次升级,使其具备更强的特征提取能力。Deeplab v2在PASCAL VOC和Cityscapes等语义分割数据集上均取得了当时的SOTA结果。关于ASPP模块的一个简单实现参考如下代码所示,先是分别定义了ASPP的卷积和池化方法,然后在其基础上定义了ASPP模块。


### 定义ASPP卷积方法
class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation,
                      dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        super(ASPPConv, self).__init__(*modules)

  ### 定义ASPP池化方法
class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        size = x.shape[-2:]
        x = super(ASPPPooling, self).forward(x)
        return F.interpolate(x, size=size, mode='bilinear',
 align_corners=False)

### 定义ASPP模块
class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)))

        rate1, rate2, rate3 = tuple(atrous_rates)
        modules.append(ASPPConv(in_channels, out_channels, rate1))
        modules.append(ASPPConv(in_channels, out_channels, rate2))
        modules.append(ASPPConv(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),)

  # ASPP前向计算流程
    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)

下图是Deeplab v2在Cityscapes数据集上的分割效果:

Deeplab v1

DeepLabv1:
Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs

ICLR 2015

在语义分割发展早期,一些研究观点认为将CNN用于图像分割主要存在两个问题:一个是下采样导致的信息丢失问题,另一个则是CNN的空间不变性问题,这与CNN本身的特性有关,这种空间不变性有利于图像分类但却不利于图像分割中的像素定位。从多尺度和上下文信息的角度来看,这两个问题是导致FCN分割效果有限的重要原因。因而,相关研究针对上述两个问题提出了Deeplab v1网络,通过在常规卷积中引入空洞(Atrous)和对CNN分割结果补充CRF作为后处理来优化分割效果。提出Deeplab v1的论文为Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs,是Deeplab系列的开篇之作。

针对第一个问题,池化下采样操作引起信息丢失,Deeplab v1给出的解决方案算是另辟蹊径。常规卷积中,使用池化下采样的主要目的是增大每个像素的感受野,但在Deeplab v1中,作者们的想法是可以不用池化也可以增大像素的感受野,尝试在卷积操作本身上重新进行设计。在Deeplab v1,一种在常规卷积核中插入空洞的设计被提出,相较于池化下采样,空洞卷积能够在不降低图像分辨率的情况下扩大像素感受野,从而就避免了信息损失的问题

空洞卷积(Dilated/Atrous Convolution)也叫扩张卷积或者膨胀卷积,字面意思上来说就是在卷积核中插入空洞,起到扩大感受野的作用。空洞卷积的直接做法是在常规卷积核中填充0,用来扩大感受野,且进行计算时,空洞卷积中实际只有非零的元素起了作用。假设以一个变量a来衡量空洞卷积的扩张系数,则加入空洞之后的实际卷积核尺寸与原始卷积核尺寸之间的关系:

K=k+(k-1)(a-1)

其中为k原始卷积核大小,a为空洞率(Dilation Rate),K为经过扩展后实际卷积核大小。除此之外,空洞卷积的卷积方式跟常规卷积一样。当a=1时,空洞卷积就退化为常规卷积。a=1,2,4时,空洞卷积示意图如下图所示。

对于语义分割而言,空洞卷积主要有三个作用:

第一是扩大感受野,具体前面已经说的比较多了,这里不做重复。但需要明确一点,池化也可以扩大感受野,但空间分辨率降低了,相比之下,空洞卷积可以在扩大感受野的同时不丢失分辨率,且保持像素的相对空间位置不变。简单而言就是空洞卷积可以同时控制感受野和分辨率。

第二就是获取多尺度上下文信息。当多个带有不同空洞率的空洞卷积核叠加时,不同的感受野会带来多尺度信息,这对于分割任务是非常重要的。

第三就是可以降低计算量,不需要引入额外的参数,如图4-13所示,实际卷积时只有带有红点的元素真正进行计算。

针对第二个问题,Deeplab v1通过引入全连接的CRF来对CNN的粗分割结果进行优化。CRF作为一种经典的概率图模型,可用于图像像素之间的关系描述,在传统图像处理中主要用于图像平滑处理。但对于CNN分割问题来说,使用短程的CRFs可能会于事无补,因为分割问题的目标是恢复图像的局部细节信息,而不是对图像做平滑处理。所以Deeplab v1提出的解决方案叫做全连接CRF(Fully Connected CRF)。

条件随机场可以优化物体的边界,平滑带噪声的分割结果,去掉物体中间的预测的孔洞,使得分割结果更加准确。

CRF是一种经典的概率图模型,简单而言就是给定一组输入序列的条件下,求另一组输出序列的条件概率分布模型,CRF在自然语言处理领域有着广泛应用。CRF在语义分割后处理中用法的基本思路如下:对于FCN或者其他分割网络的粗粒度分割结果而言,每个像素点i具有对应的类别标签x_i和观测值y_i,以每个像素为节点,以像素与像素之间的关系作为边即可构建一个CRF模型。在这个CRF模型中,我们通过观测变量y_i来预测像素i对应的标签值x_i。CRF用于像素预测的结构如下图所示。

全连接CRF使用的能量函数为:

Deeplab v1模型流程如下图所示。输入图像经过深度卷积网络(DCNN)后生成浓缩的、粗粒度的语义特征图,再经过双线性插值上采样后形成粗分割结果,最后经全连接的CRF后处理生成最终的分割结果图。

下图是Deeplab v1在无CRF和有CRF后处理的分割效果图:

总的来看,Deeplab v1有如下几个有点:

(1)速度快。基于空洞卷积的CNN分割网络,能够保证8秒每帧(Frame Per Second,FPS)的推理速度,后处理的全连接CRF也仅需要0.5秒。

(2)精度高。Deeplab v1在当时取得了在PASCAL VOC数据集上SOTA的分割模型表现,准确率超过此前SOTA的7.2%。

(3)简易性。DCNN和CRF均为成熟的算法模块,将其进行简单的级联即可在当前的语义分割模型上取得好的效果。

Deeplab v1 PyTorch参考实现代码如下:

https://github.com/wangleihitcs/DeepLab-V1-PyTorch

PSPNet:Pyramid Attention Networkfor Semantic Segmentation

原文地址:https://arxiv.org/pdf/1612.01105.pdf
论文代码:https://github.com/hszhao/PSPNet

解读:
本文的主要创新点在于提出了空间金字塔池化模块,简单来说就是在encoder之后得到feature map X,使用不同尺寸的kernel进行池化(avepooling)操作,之后对得到的feature map进行上采样,使得尺寸和X大小一样,然后进行级联,进而经过卷积操作得到map。和经典的FCN的区别是在encoder和decoder之间加入了PSP模块。

场景解析(scene parsing)是语义分割的一个重要应用方向,区别于一般的语义分割任务,场景解析需要在复杂的自然图像场景下对更庞大的物体类别的每一个像素进行分类,场景解析在自动驾驶和机器人感知等方向应用广泛。但由于自然场景的复杂性、语义标签的多样性以及目标物体的多变性,对于场景解析问题的研究一直存在诸多困难。

场景解析一般基于FCN和空洞卷积网络来进行结构设计,后续的改进方案主要有两个方向,一种是多尺度特征集成,学界普遍的观点认为深层特征能够提取图像的语义信息但缺乏定位信息,将多尺度的特征集成起来能够显著提升模型效果。另一个则是结合图模型的结构预测,比如说使用CRF进行后处理来优化图像分割结果,这些改进方案虽然在一定程度上都能缓解像素定位问题,但对于更为复杂的图像场景仍然效果有限。将全局信息加入到语义分割网络中的设计也有很多研究提到,比如ParseNet,通过给FCN补充全局平均池化来提升分割效果。但这种全局信息不足以在一些复杂的场景解析数据集上体现效果,比如包括150个语义类别的ADE20K数据集。总结来看,FCN对于复杂的场景解析存在如下三个问题:

(1)像素上下文关系不匹配。在语义分割中,图像上下文信息非常重要,不同的像素类别之间往往存在着共现的视觉模式,比如说飞机通常只会出现在跑道上或者空中,而不会出现在马路上,所以飞机和跑道以及天空这三个对象之间存在着共现模式。

(2)语义标签的混淆。ADE20K数据集中有大量语义相近的类别,比如高山与丘陵、建筑与摩天大楼等。FCN在这种容易混淆的类别上很难做出准确的判断。

(3)小目标类别预测困难。在场景解析中,目标物体可以是任意大小,所以有时候对于小目标物体预测效果就很差,比如说对路灯和广告牌的识别,FCN表现就非常不好。

针对上述问题,相关学者基于空间金字塔池化(Spatial Pyramid Pooling, SPP)提出了一种更为广泛的、基于不同图像区域的全局上下文信息集成结构:PSPNet,提出PSPNet的论文为Pyramid Scene Parsing Network,是多尺度上下文结构方向的一个重要的网络设计。PSPNet网络结构如图1所示。

可以看到,将输入图像(a)经过CNN网络提取之后的特征图(b)送入到一个金字塔池化模块(Pyramid Pooling Module,PPM),在该模块中(c),使用四种不同尺度的全局池化(1*1、2*2、3*3和6*6)进行下采样并结合1*1卷积来生成降维后的上下文信息表征。然后直接使用双线性插值对低分辨率的特征图进行上采样,并与阶段(b)的特征图进行连接,形成最终的全局金字塔池化特征,最后再进行一组卷积后形成最终的语义分割预测结果(d)。基于上述结构,PSPNet对于场景解析任务能够提供有效的全局上下文先验信息,PPM相较于直接的全局池化,能够收集多尺度的信息表征,并且在计算开销上相比于FCN也增加不多。此外PSPNet训练时为防止梯度消失还添加了辅助损失函数作为一种深监督机制。

 PSPNet在ADE20K、PASCAL VOC 2012以及Cityscapes数据集上分别进行了测试,均能达到当时SOTA的分割水平。另外也在辅助损失函数和预训练模型等方面做了充分的消融实验(ablation study)。PSPNet在ADE20K数据集上的一组测试效果如图2所示。在三张测试图像上,FCN不能够捕捉图像上下文关系(将第一张图像中水面上的船只识别为汽车),但PSPNet均能够得到很好的分割效果。

网络结构

  • encoder。使用了预训练的ResNet,里面使用了孔洞卷积(后面几层没有下采样,全部使用空洞卷积)。最后输出的feature map是原图的1/8。类似下图

在encoder之后,使用了金字塔池化模块,使用了四种尺寸的金字塔,池化所用的kerne分别1×1, 2×2, 3×3 and 6×6。池化之后上采样,然后将得到的feature map,包括池化之前的做一个级联(concatenate),后面接一个卷积层得到最终的预测图像。作者提到,和全局池化相比,金字塔池化能提取多尺度的信息。

下述代码给出了PSPNet网络结构的一个简单实现流程。


### 定义PPM模块类
class _PyramidPoolingModule(nn.Module):
    def __init__(self, in_dim, reduction_dim, setting):
        super(_PyramidPoolingModule, self).__init__()
        self.features = []
        for s in setting:
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(s),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim, momentum=.95),
                nn.ReLU(inplace=True)
            ))
        self.features = nn.ModuleList(self.features)

  # PPM前向计算流程
    def forward(self, x):
        x_size = x.size()
        out = [x]
        for f in self.features:
            out.append(F.upsample(f(x), x_size[2:], mode='bilinear'))
        out = torch.cat(out, 1)
        return out

  ### 定义PSPNet类
class PSPNet(nn.Module):
    def __init__(self, num_classes, use_aux=True):
        super(PSPNet, self).__init__()
        self.use_aux = use_aux
    # 使用ResNet101作为预训练模型
        resnet = models.resnet101()
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1,1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1,1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    
        self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
        self.final = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1)
        )
    # 深监督辅助损失
        if use_aux:
            self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1)
            initialize_weights(self.aux_logits)
        initialize_weights(self.ppm, self.final)

  ### PSPNet前向计算流程
    def forward(self, x):
        x_size = x.size()
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.training and self.use_aux:
            aux = self.aux_logits(x)
        x = self.layer4(x)
        x = self.ppm(x)
        x = self.final(x)
        if self.training and self.use_aux:
            return F.upsample(x, x_size[2:], mode='bilinear'), F.upsample(aux, x_size[2:], mode='bilinear')
        return F.upsample(x, x_size[2:], mode='bilinear')

Unet++

代码:https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py

论文:https://arxiv.org/abs/1807.10165

论文作者的知乎讲解:https://zhuanlan.zhihu.com/p/44958351

UNet的编解码结构一经提出以来,大有统一深度学习图像分割之势,后续基于UNet的改进方案也经久不衰,一些研究者也在从网络结构本身来思考UNet的有效性。比如说编解码网络应该取几层,跳跃连接是否能够有更多的变化以及什么样的结构训练起来更加有效等问题。UNet本身是针对医学图像分割任务而提出来的网络结构,该任务不像自然图像分割,对分割精度要求并不是十分严格。但对于医学图像而言,器官和病灶的分割则要求极高的精确性,因为很多时候分割效果的好坏直接关系到对应的临床诊断决策。出于上述两个方面的动机,即设计更好的UNet结构和提升医学图像分割的精度,相关研究者提出了一种嵌套的UNet结构(Nested UNet),也叫UNet++,提出UNet++的论文为UNet++: A Nested U-Net Architecture for Medical Image Segmentation,发表于2018年的医学图像计算和计算机辅助干预(Medical Image Computing and Computer Assisted Intervention,MICCAI)会议上。

UNet++取名为嵌套的UNet,就在于其整体编解码网络结构中还嵌套了编解码的子网络(sub-networks),在此基础上重新设计UNet中间的跳跃连接,并补充了深监督机制加速网络训练收敛。完整的UNet++结构如下图所示。

图中黑色部分为原始的UNet结构,包括编码器下采样、解码器上采样和黑色虚线的跳跃连接三个部分;绿色部分即嵌套的UNet子网络,包括卷积和上采样两部分,而蓝色虚线部分就是UNet++重新设计后的跳跃连接,这部分跟DenseNet的密集连接类似,这里是为子网络提供跳跃连接;最上面红黑连线则是UNet++补充的深监督机制,目的是为了网络能够顺利得到训练。

下面我们从结构设计的角度来对UNet++进行解读。关于UNet结构,最首要的问题就是网络应该有几层,原始的UNet结构用了4层下采样和4层上采样,那么是不是4层就足以满足所有的分割任务需要?答案是否定的。通过本节之前的网络结构分析,我们已经知道,浅层网络能够提取图像粗粒度特征,获取图像基本形态;深层网络能够提取图像的抽象特征,获取图像语义信息,总之浅有浅的侧重,深有深的好处。同之前RefineNet的观点一样,UNet++的作者认为,不管是浅层、深层还是中层,所有层次的特征对于最后的分割都是重要的。有的数据分割任务简单,图像信息单一,可能浅层网络就足以达到很好的效果,而有的数据任务复杂,图像信息丰富,可能需要更深层的网络结构才能达到不错的效果,之前的UNet结构设计很难同时照顾到这种普适性。而UNet++通过设计不同深度的嵌套UNet子网络来实现这种普适性,所以UNet的深度到这里就解决了。

第二个问题则是加入不同深度的嵌套网络后,跳跃连接部分该如何调整。在UNet中,跳跃连接由同层编码器直连到编码器上采样对应层。但加入嵌套子网络后,UNet中原先的长连接就不复存在了,取而代之的是各子网络中的短连接。UNet++的作者们认为,长连接在UNet中是有必要的,能够将图像中前后信息联系起来,对于下采样造成的信息损失有很好的补充作用。所以,UNet++又参考DenseNet的密集连接设计,给嵌套网络补充了长连接,如上图所示。

但是这样又带来了第三个问题:反向传播的时候中间部分可能会收不到由损失函数反传回来的梯度。所以见招拆招,UNet++又通过深监督的方法来强行加梯度,帮助网络正常进行训练。但深监督对于UNet++的好处绝不仅仅限于此,通过不同深监督损失函数,UNet++可以通过网络剪枝来实现可伸缩性。所以,总结来说UNet++相较于原始的UNet,有如下两个优势:

(1)通过嵌套子网络和长短连接来整合不同层次的图像特征,使得网络分割精度更高;

(2)灵活的网络结构配合深监督机制,让参数量巨大的深度网络在可接受的精度范围内能够大幅度的缩减参数量。

UNet++与UNet等网络分割效果对比如下图所示。

UNet++也进一步壮大了UNet家族网络,后续基于其的改进版本也有很多,比如Attention UNet++、UNet 3+等。下述代码给出了UNet++的一个实现参考。


class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()
        nb_filter = [32, 64, 128, 256, 512]
        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)


    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            output = self.final(x0_4)
            return output

最后,给出作者在知乎上的一段总结:

回顾一下这次分享,我问了好多问题,也提供了其中一些我个人给出的解释。

一顿听下来,热心的网友可能会懵了,就这么个分割网络,都能说这么久,要我说就放个结构图,说这个网络很牛逼,再告知一下代码在哪儿,谢谢大家就完事儿了。

其实搞学术做研究不是这样子的,UNet++肯定马上就会被更强的结构所代替,但是要设计出更强的结构,你得首先明白这个结构,甚至它的原型U-Net设计背后的心路历程。与其和大家分享一个苍白的分割网络,我更愿意分享的是这个项目背后从开始认识U-Net,到分析它的组成,到批判性的解读,再到改进思路的形成,实验设计,像刚刚分享过程中一次次尴尬的自问自答,中间那些非常饱满的心路历程。这也是我在博士的两年中学到的做研究的范式。

说句题外话,就跟玩狼人杀一样,你一上来就说自己预言家,验了谁谁谁,那没人信的呀,你得说清楚为什么要验他,警徽怎么留,把这些心路历程都盘盘清楚,身份才能做实。

我在微博上也看到有人说,也想到了非常类似的结构,实验也快做完了,看到了我这篇论文心就凉了。其实网络结构怎么样真的不重要,重要的是你怎么能把故事给讲清楚,要是讲完以后还能够引起更多的思考和讨论那就更好了。我在分享中提到了很多我们论文中的不足之处,也非常欢迎大家可以批评指正。