Rotamer-Free Protein Sequence Design Based on Deep Learning and Self-Consistency

论文地址: https://www.nature.com/articles/s43588-022-00273-6

中国科大用深度学习实现高实验成功率的蛋白质序列从头设计

中国科学技术大学生命科学与医学部刘海燕教授、陈泉副教授团队与信息科学技术学院李厚强教授团队合作,开发了一种基于深度学习为给定主链结构从头设计氨基酸序列的算法ABACUS-R,在实验验证中,ABACUS-R的设计成功率和设计精度超过了原有统计能量模型ABACUS。相关成果以“Rotamer-Free Protein Sequence Design Based on Deep Learning and Self-Consistency”为题于北京时间2022年7月21日发表于Nature Computational Science。

刘海燕教授、陈泉副教授团队致力于发展数据驱动的蛋白质设计方法,建立并实验验证了利用神经网络能量函数从头设计主链结构的SCUBA模型,以及对给定主链结构设计氨基酸序列的统计能量函数ABACUS。然而,通过优化能量函数来进行序列设计的方法在成功率、计算效率等方面仍有不足。近期有多项研究表明,用深度学习进行氨基酸序列设计能够在天然氨基酸残基类型恢复率等计算指标上超过能量函数方法;但截至目前已正式发表的工作中,对相关方法的实验验证结果远未达到能量函数方法的成功率。该论文报道的ABACUS-R模型,则不仅在计算指标上超过ABACUS,在实验验证中成功率和结构精度也有大幅提高。

用ABACUS-R进行序列设计的方法由两部分组成(图1)。第一部分为预训练的编码器-解码器网络:该网络用Transformer把中心氨基酸残基的化学和空间结构环境映射为隐空间表示向量,再用多层感知机网络将该向量解码为包括中心残基氨基酸类型在内的多种真实特征(图1a)。在方法的第二部分,经用非冗余天然蛋白序列结构数据训练后,ABACUS-R编码器-解码器被用于给定主链结构的全部或部分氨基酸序列从头设计。具体为:从任意初始序列出发,对各个类型待定残基分别应用ABACUS-R编码器-解码器,得到环境依赖的最适宜残基类型,并反复迭代至不同位点的残基类型最大程度自洽(图1b)。

图1. 用ABACUS-R模型进行蛋白质序列设计的原理。(a) 预训练的编码器-解码器网络;(b)采用自洽迭代策略进行全序列从头设计。

ABACUS-R方法包含两部分:(1)一个encoder-decoder网络被预训练用以推断给定骨架的局部环境时中心残基的侧链类型 (2)用该encoder-decoder网络连续更新每个残基的类型,最终收敛获得自洽(self-consistent)。网络的输入是中心残基与空间上最邻近(Cα间距离)k个残基组成的局部结构。邻近残基的特征包含空间层面的相对位置与取向信息(XSPA)、序列层面的相对位置信息(XRSP)以及邻近残基的残基类型(XAA)。第i个中心残基的特征包含全零的XSPA、被mask的XAA以及骨架上的15个ϕi−2ψi−2ωi−2 ⋯ ϕi+2ψi+2ωi+2,这些特征组合起来会被映射到与邻近残基特征相同的维度。以上模型输入的信息都是旋转平移不变的。局部结构中的所有残基的特征经过可学习的映射后融合后,得到每个残基总特征En。{En; n = 0, 1, 2, … , k}经过基于 transformer架构的encoder-decoder,预测每个中心残基的类型以及其他辅助任务。

自洽迭代设计的方法是:对序列随机初始化,第一轮随机选择80%的残基通过encoder-decoder并行预测其残基类型,以后每轮随机选择的残基数目逐渐下降。最终的设计结果会逐渐收敛。

在理论验证的基础上,中国科大团队尝试了实验表征用ABACUS-R对3个天然主链结构重新设计的57条序列;其中86%的序列(49条)可溶表达并能折叠为稳定单体;实验解析的5个高分辨晶体结构与目标结构高度一致(主链原子位置均方根位移在1Å以下)(图2)。此外,与以前报道的从头设计蛋白相似,ABACUS-R从头设计的蛋白表现出超高热稳定性,去折叠温度大多可达100℃以上。

作者将PDB中的非冗余结构按照两种不同的方式划分了95%作为训练集、5%作为测试集,第一种划分方式确保测试集的结构不会存在训练集中出现过的CATH拓扑,训练得到的模型为Model­­eval;第二种划分方式时随机划分Modelfinal。Model­­eval可以用来评估模型能力的无偏向性的表现,而Modelfinal使用了更丰富的数据训练表现应当更好。

表现评估

Encoder-decoder的架构可以进行多任务学习,除了训练序列的恢复的任务以外,还可以预测二级结构、SASA、B-factor与侧链扭转角χ1、χ2。多个任务可以增强模型设计序列的能力(图2a),Model­­eval与Model­­final都可以在测试集上最好取得50%左右准确度。在测试集上的结果显示,虽然有些残基类型没有恢复正确,但是模型也学习到了替换为性质相似的残基(图2b)。

相较于ABACUS模型,ABACUS-R序列设计更高的成功率和结构精度进一步增强了数据驱动蛋白质从头设计方法的实用性。ABACUS-R还提供了一种对蛋白质局部结构信息的预训练表示方式,可用于序列设计以外的其他任务。

Decoder网络输出的是每个位置上残基类型的-logP,类似于选择不同残基对应的能量,所以作者将ProTherm数据集中蛋白突变的ΔΔG与模型计算出相应的−ΔΔlogits进行了比较,发现二者有一定的相关性(图2d),说明模型一定程度上学习到了能量。

接着,作者验证了模型的自洽性,测试集中100个蛋白属于CATH的三个大类,对其中的每个蛋白从随机序列出发设计10条序列,随着迭代的次数变多,平均-logP会趋于收敛(图3a),同时未收敛的残基比例也会收敛(图3b)。不同CATH类别的骨架上取得的序列恢复率差距不大(图3c)。同一蛋白骨架设计出的序列会有很高的相似性(0.76-0.89)。设计出的序列与天然序列相比,序列的成分高度相似(图3d),Pearson相关系数达到了0.93,但GLU、ALA与LYS出现得更频繁,而Gln、His、Met出现得更少。此外,ABACUS-R设计出的序列与ABACUS设计出的序列相比,平均每个残基的Rosetta打分更低(图3e),而平均的-logP打分却更高(图3f),这意味着ABACUS-R学习到的能量与Rosetta打分函数存在正交的部分。

图3. ABACUS-R的自洽能力、设计能力以及学习到的能量与Rosetta打分的比较

相较于其他深度学习方法在单个残基恢复任务上的表现,ABACUS-R超过了除DenseCPD外的所有方法(表1),在整条序列重设计任务上ABACUS-R在两个测试集上都取得了最好的表现

实验验证

最后,作者在3种天然骨架(PDB ID: 1r26, 1cy5 and 1ubq)上通过实验验证了ABACUS-R的设计能力。设计的方法有两种:第一种采用迭代自洽的设计方法(生成序列的多样性低),第二种采用迭代时对decoder输出结果进行采样(生成序列的多样性高,但-logP能量也略高)。

第一种方法设计的27条序列有26条成功表达,体积排阻色谱与1H NMR实验结果显示所有的蛋白都以单体形式存在,示差扫描量热实验显示5条序列有很好的热稳定性( 97~117 C )。最终,1r26的3个设计与1cy5的1个设计成功解出了晶体结构,Cα RMSD位于0.51~0.88 Å,而1ubq的1个设计虽然没有解出结构,但已有的实验结果显示它折叠成了明确的三维结构。

第二种方法对同一骨架设计的序列相似度在58%左右。30条设计的序列中,25条被成功表达,23条能被可溶地纯化。所有设计同样都是单体存在并且折叠成了明确的三维结构,5个设计有很好的热稳定性(85~118 C)。最终,1r26的1个设计被成功解出了晶体结构,Cα RMSD为0.67 Å。相较方法一的自洽设计,方法二设计成功率下降,成功设计的蛋白热稳定性也略微下降,但作者认为可以接受。

最后,作者展示了所有1r26设计晶体结构核心的侧链pack(图4a,b),以及 1cy5设计晶体结构的侧链的极性作用(图4c),说明了ABACUS-R学会了设计侧链的组合以pack好的结构。

深度学习语义分割理论与实战指南

看到github中一个写的很棒的图像分割的介绍和代码实现:

https://github.com/luwill/Semantic-Segmentation-Guide

对于刚入门的同学来说,十分友好。

引言

图像分类、目标检测和图像分割是基于深度学习的计算机视觉三大核心任务。三大任务之间明显存在着一种递进的层级关系,图像分类聚焦于整张图像,目标检测定位于图像具体区域,而图像分割则是细化到每一个像素。基于深度学习的图像分割具体包括语义分割、实例分割和全景分割。语义分割的目的是要给每个像素赋予一个语义标签。语义分割在自动驾驶、场景解析、卫星遥感图像和医学影像等领域都有着广泛的应用前景。本文作为基于PyTorch的语义分割技术手册,对语义分割的基本技术框架、主要网络模型和技术方法提供一个实战性指导和参考。

1. 语义分割概述

图像分割主要包括语义分割(Semantic Segmentation)和实例分割(Instance Segmentation)。那语义分割和实例分割具体都是什么含义?二者又有什么区别和联系?语义分割是对图像中的每个像素都划分出对应的类别,即实现像素级别的分类;而类的具体对象,即为实例,那么实例分割不但要进行像素级别的分类,还需在具体的类别基础上区别开不同的个体。例如,图像有多个人甲、乙、丙,那边他们的语义分割结果都是人,而实例分割结果却是不同的对象。另外,为了同时实现实例分割与不可数类别的语义分割,相关研究又提出了全景分割(Panoptic Segmentation)的概念。语义分割、实例分割和全景分割具体如图1(b)、(c)和(d)图所示。

Fig1. Image Segmentation
在开始图像分割的学习和尝试之前,我们必须明确语义分割的任务描述,即搞清楚语义分割的输入输出都是什么。输入是一张原始的RGB图像或者单通道图像,但是输出不再是简单的分类类别或者目标定位,而是带有各个像素类别标签的与输入同分辨率的分割图像。简单来说,我们的输入输出都是图像,而且是同样大小的图像。如图2所示。 

Fig2. Pixel Representation
类似于处理分类标签数据,对预测分类目标采用像素上的one-hot编码,即为每个分类类别创建一个输出的通道。如图3所示。 

Fig3. Pixel One-hot
图4是将分割图添加到原始图像上的叠加效果。这里需要明确一下mask的概念,在图像处理中我们将其译为掩码,如Mask R-CNN中的Mask。Mask可以理解为我们将预测结果叠加到单个通道时得到的该分类所在区域。 

Fig4. Pixel labeling
所以,语义分割的任务就是输入图像经过深度学习算法处理得到带有语义标签的同样尺寸的输出图像。

。。。。。。。(剩余部分在github可查看)

Attention UNet

论文: https://arxiv.org/abs/1804.03999

以CNN为基础的编解码结构在图像分割上展现出了卓越的效果,尤其是医学图像的自动分割上。但一些研究认为以往的FCN和UNet等分割网络存在计算资源和模型参数的过度和重复使用,例如相似的低层次特征被级联内的所有网络重复提取。针对这类普遍性的问题,相关研究提出了给UNet添加注意力门控(Attention Gates, AGs)的方法,形成一个新的图像分割网络结构:Attention UNet。提出Attention UNet的论文为Attention U-Net: Learning Where to Look for the Pancreas,发表在2018年CVPR上。注意力机制原先是在自然语言处理领域被提出并逐渐得到广泛应用的一种新型结构,旨在模仿人的注意力机制,有针对性的聚焦数据中的突出特征,能够使得模型更加高效。

Attention UNet的网络结构如下图所示,需要注意的是,论文中给出的3D版本的卷积网络。其中编码器部分跟UNet编码器基本一致,主要的变化在于解码器部分。其结构简要描述如下:编码器部分,输入图像经过两组3*3*3的3D卷积和ReLU激活,然后再进行最大池化下采样,经过3组这样的卷积-池化块之后,网络进入到解码器部分。编码器最后一层的特征图除了直接进行上采样外,还与来自编码器的特征图进行注意力门控计算,然后再与上采样的特征图进行合并,经过三次这样的上采样块之后即可得到最终的分割输出图。相比于普通UNet的解码器,Attention UNet会将解码器中的特征与编码器连接过来的特征进行注意力门控处理,然后再与上采样进行拼接。经过注意力门控处理后得到的特征图会包含不同空间位置的重要性信息,使得模型能够重点关注某些目标区域。

我们将Attention UNet的注意力门控单独拿出来进行分析,看AGs是如何让模型能够聚焦到目标区域的。如图中上图所示,将Attention UNet网络中的一个上采样块单独拿出来,其中x_l为来自同层编码器的输出特征图,g表示由解码器部分用于上采样的特征图,这里同时也作为注意力门控的门控信号参数与x_l的注意力计算,而x^hat_l即为经过注意力门控计算后的特征图,此时x^hat_l是包含了空间位置重要性信息的特征图,再将其与下一层上采样后的特征图进行合并才得到该上采样块最终的输出。

将x_l和g_i计算得到的注意力系数再次与x_l相乘即可得到x^hat_l,这种经过与注意力系数相乘后的特征图会让图像中不相关的区域值变小,目标区域的值相对会变大,提升网络预测速度同时,也会提高图像的分割精度。论文中的各项实验结果也表明,经过注意力门控加成后后UNet,效果均要优于原始的UNet。下述代码给出了Attention UNet的一个2D参考实现,并且下采样次数由论文中的3次改为了4次。


### 定义Attention UNet类
class Att_UNet(nn.Module):
    def __init__(self,img_ch=3,output_ch=1):
        super(Att_UNet, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 =
       nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
    
  ### 定义前向传播流程
    def forward(self,x):
        # 编码器部分
        x1 = self.Conv1(x)
        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)
        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)
        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # 解码器+连接部分
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5,x=x4)
        d5 = torch.cat((x4,d5),dim=1)        
        d5 = self.Up_conv5(d5)        
        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4,x=x3)
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3,x=x2)
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_conv3(d3)
        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2,x=x1)
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_conv2(d2)
        d1 = self.Conv_1x1(d2)
        return d1
  
  ### 定义Attention门控块
class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
    # 注意力门控向量
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int,
            kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
            )
        # 同层编码器特征图向量
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int,
            kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
        )
    # ReLU激活函数
    self.relu = nn.ReLU(inplace=True)
    # 卷积+BN+sigmoid激活函数
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1,
            kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        
    ###  Attention门控的前向计算流程 
    def forward(self,g,x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1+x1)
        psi = self.psi(psi)
        return x*psi

总结来说,Attention UNet提出了在原始UNet基础添加注意力门控单元,注意力得分能够使得图像分割时聚焦到目标区域,该结构作为一个通用结构可以添加到任何任务类型的神经网络结构中,在语义分割网络中对前景目标区域的像素更具有敏感度。Attention UNet壮大了UNet家族网络,此后基于其的改进版本也层出不穷。

SegNet

论文(2015):SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

Github:https://github.com/alexgkendall/caffe-segnet

  • 把本文提出的架构和FCN、DeepLab-LargeFOV、DeconvNet做了比较,这种比较揭示了在实现良好分割性能的前提下内存使用情况与分割准确性的权衡。
  • SegNet的主要动机是场景理解的应用。因此它在设计的时候考虑了要在预测期间保证内存和计算时间上的效率。
  • 定量的评估表明,SegNet在和其他架构的比较上,时间和内存的使用都比较高效。

SegNet论文提出了max pooling的改进版,使用该pooling操作既可以进行下采样操作,也可以进行上采样操作。在下采样操作中同时输出pooling后的结果和pooling过程中的索引。在上采样操作中,利用下采样对应位置的索引,进行上采样操作,这样的优势在于记住了最亮特征像素的空间位置。(去除了unet里面的反卷积操作)

优点,

  1. 可以提高物体边界的分割效果
  2. 相比反卷积操作,减少了参数数量,减少了运算量,相比resize操作,减少了插值的运算量,而实际增加的索引参数也很少。
  3. 该pooling操作可以应用于任何基于编码-解码的分割模型。

SegNet网络结构如下图所示,是一个编解码完全对称的结构。其编码器直接用了VGG16的结构,并将全连接层全部改为卷积层,实际训练时可使用VGG16的预训练权重进行初始化;编码器将13层卷积层分为5组卷积块,每组卷积块之间用最大池化层进行下采样。作为一个对称结构,SegNet解码器也有13层卷积层,同样分为5组卷积块,每组卷积块之间用双线性插值和最大池化位置索引进行上采样,这也是SegNet最大的特色。

SegNet研究团队认为编码器下采样过程中图像信息损失较多,直接存储所有卷积块的特征图又非常占用内存,因而在SegNet中提出在每一次最大池化下采样前存储最大池化的位置索引(Max-pooling indices),即记住最大池化操作中,最大值在2*2池化窗口中的位置。每个2*2窗口仅需要2 bits内存存储量,这种池化位置索引可用于上采样解码时恢复图像信息。下图给出了SegNet与FCN之间的上采样方法对比。可以观察到,SegNet使用双线性插值并结合最大池化位置索引进行上采样,而FCN则是基于去卷积结合编码器卷积特征图进行上采样。

SegNet这种轻量化的上采样方式,不仅能够提升图像边界分割效果,在端到端的实时分割项目中速度也非常快,并且这种结构设计可以配置到任意的编解码网络中,是一种优秀的分割网络设计方式。下述代码给出了SegNet的一个简易的结构实现,因为SegNet解码器的特殊性,我们单独定义了一个解码器类,编码器部分直接使用VGG16的预训练权重层,然后在编解码器基础上搭建SegNet并定义前向计算流程。


# 导入PyTorch相关模块
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torchvision import models

# 定义SegNet解码器类
class SegNetDec(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
        ]
        layers += [
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
        ] * num_layers
        layers += [
            nn.Conv2d(in_channels // 2, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.decode = nn.Sequential(*layers)

    def forward(self, x):
        return self.decode(x)

### 定义SegNet类
class SegNet(nn.Module):
    def __init__(self, classes):
        super().__init__()
    # 编码器使用vgg16预训练权重
        vgg16 = models.vgg16(pretrained=True)
        features = vgg16.features
        self.enc1 = features[0: 4]
        self.enc2 = features[5: 9]
        self.enc3 = features[10: 16]
        self.enc4 = features[17: 23]
        self.enc5 = features[24: -1]
    # 编码器卷积层不参与训练
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.requires_grad = False
    
        self.dec5 = SegNetDec(512, 512, 1)
        self.dec4 = SegNetDec(512, 256, 1)
        self.dec3 = SegNetDec(256, 128, 1)
        self.dec2 = SegNetDec(128, 64, 0)

        self.final = nn.Sequential(*[
            nn.Conv2d(64, classes, 3, padding=1),
            nn.BatchNorm2d(classes),
            nn.ReLU(inplace=True)
        ])
  # 定义SegNet前向计算流程
    def forward(self, x):
        x1 = self.enc1(x)
        e1, m1 = F.max_pool2d(x1, kernel_size=2, stride=2,
 return_indices=True)
        x2 = self.enc2(e1)
        e2, m2 = F.max_pool2d(x2, kernel_size=2, stride=2,
 return_indices=True)
        x3 = self.enc3(e2)
        e3, m3 = F.max_pool2d(x3, kernel_size=2, stride=2,
 return_indices=True)
        x4 = self.enc4(e3)
        e4, m4 = F.max_pool2d(x4, kernel_size=2, stride=2,
 return_indices=True)
        x5 = self.enc5(e4)
        e5, m5 = F.max_pool2d(x5, kernel_size=2, stride=2,
 return_indices=True)

        def upsample(d):
            d5 = self.dec5(F.max_unpool2d(d, m5, kernel_size=2,
 stride=2, output_size=x5.size()))
            d4 = self.dec4(F.max_unpool2d(d5, m4, kernel_size=2,
 stride=2, output_size=x4.size()))
            d3 = self.dec3(F.max_unpool2d(d4, m3, kernel_size=2,
 stride=2, output_size=x3.size()))
            d2 = self.dec2(F.max_unpool2d(d3, m2, kernel_size=2,
 stride=2, output_size=x2.size()))
            d1 = F.max_unpool2d(d2, m1, kernel_size=2, stride=2,
 output_size=x1.size())
            return d1

        d = upsample(e5)
        return self.final(d)

图像分割损失函数loss 总结+代码

汇总语义分割中常用的损失函数:

  1. cross entropy loss
  2. weighted loss
  3. focal loss
  4. dice soft loss
  5. soft iou loss
  6. Tversky Loss
  7. Generalized Dice Loss
  8. Boundary Loss
  9. Exponential Logarithmic Loss
  10. Focal Tversky Loss
  11. Sensitivity Specificity Loss
  12. Shape-aware Loss
  13. Hausdorff Distance Loss

参考论文Medical Image Segmentation Using Deep Learning:A Survey

论文地址:A survey of loss functions for semantic segmentation
代码地址https://github.com/shruti-jadon/Semantic-Segmentation-Loss-Functions
项目推荐https://github.com/JunMa11/SegLoss

图像分割一直是一个活跃的研究领域,因为它有可能修复医疗领域的漏洞,并帮助大众。在过去的5年里,各种论文提出了不同的目标损失函数,用于不同的情况下,如偏差数据,稀疏分割等。

图像分割可以定义为像素级别的分类任务。图像由各种像素组成,这些像素组合在一起定义了图像中的不同元素,因此将这些像素分类为一类元素的方法称为语义图像分割。在设计基于复杂图像分割的深度学习架构时,通常会遇到了一个至关重要的选择,即选择哪个损失/目标函数,因为它们会激发算法的学习过程。损失函数的选择对于任何架构学习正确的目标都是至关重要的,因此自2012年以来,各种研究人员开始设计针对特定领域的损失函数,以为其数据集获得更好的结果。

这些损失函数可大致分为4类:基于分布的损失函数,基于区域的损失函数,基于边界的损失函数和基于复合的损失函数( Distribution-based,Region-based,  Boundary-based,  and  Compounded)

1、cross entropy loss

用于图像语义分割任务的最常用损失函数是像素级别的交叉熵损失,这种损失会逐个检查每个像素,将对每个像素类别的预测结果(概率分布向量)与我们的独热编码标签向量进行比较。

假设我们需要对每个像素的预测类别有5个,则预测的概率分布向量长度为5:

每个像素对应的损失函数为:

整个图像的损失就是对每个像素的损失求平均值。

特别注意的是,binary entropy loss 是针对类别只有两个的情况,简称 bce loss,损失函数公式为:

#二值交叉熵,这里输入要经过sigmoid处理  
import torch  
import torch.nn as nn  
import torch.nn.functional as F  
nn.BCELoss(F.sigmoid(input), target)  
#多分类交叉熵, 用这个 loss 前面不需要加 Softmax 层  
nn.CrossEntropyLoss(input, target)

2、weighted loss

由于交叉熵损失会分别评估每个像素的类别预测,然后对所有像素的损失进行平均,因此我们实质上是在对图像中的每个像素进行平等地学习。如果多个类在图像中的分布不均衡,那么这可能导致训练过程由像素数量多的类所主导,即模型会主要学习数量多的类别样本的特征,并且学习出来的模型会更偏向将像素预测为该类别。

FCN论文和U-Net论文中针对这个问题,对输出概率分布向量中的每个值进行加权,即希望模型更加关注数量较少的样本,以缓解图像中存在的类别不均衡问题。

比如对于二分类,正负样本比例为1: 99,此时模型将所有样本都预测为负样本,那么准确率仍有99%这么高,但其实该模型没有任何使用价值。

为了平衡这个差距,就对正样本和负样本的损失赋予不同的权重,带权重的二分类损失函数公式如下:

要减少假阴性样本的数量,可以增大 pos_weight;要减少假阳性样本的数量,可以减小 pos_weight。

class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss):  
   """  
   Network has to have NO NONLINEARITY!  
   """  
   def __init__(self, weight=None):  
       super(WeightedCrossEntropyLoss, self).__init__()  
       self.weight = weight  
  
   def forward(self, inp, target):  
       target = target.long()  
       num_classes = inp.size()[1]  
  
       i0 = 1  
       i1 = 2  
  
       while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once  
           inp = inp.transpose(i0, i1)  
           i0 += 1  
           i1 += 1  
  
       inp = inp.contiguous()  
       inp = inp.view(-1, num_classes)  
  
       target = target.view(-1,)  
       wce_loss = torch.nn.CrossEntropyLoss(weight=self.weight)  
  
       return wce_loss(inp, target)

3、focal loss

上面针对不同类别的像素数量不均衡提出了改进方法,但有时还需要将像素分为难学习和容易学习这两种样本。

容易学习的样本模型可以很轻松地将其预测正确,模型只要将大量容易学习的样本分类正确,loss就可以减小很多,从而导致模型不怎么顾及难学习的样本,所以我们要想办法让模型更加关注难学习的样本。

对于较难学习的样本,将 bce loss 修改为:

其中的 γ 通常设置为2。

通过这种修改,就可以使模型更加专注于学习难学习的样本。

而将这个修改和对正负样本不均衡的修改合并在一起,就是大名鼎鼎的 focal loss:

class FocalLoss(nn.Module):  
   """  
   copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py  
   This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in  
   'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'  
       Focal_Loss= -1*alpha*(1-pt)*log(pt)  
   :param num_class:  
   :param alpha: (tensor) 3D or 4D the scalar factor for this criterion  
   :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more  
                   focus on hard misclassified example  
   :param smooth: (float,double) smooth value when cross entropy  
   :param balance_index: (int) balance class index, should be specific when alpha is float  
   :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.  
   """  
  
   def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):  
       super(FocalLoss, self).__init__()  
       self.apply_nonlin = apply_nonlin  
       self.alpha = alpha  
       self.gamma = gamma  
       self.balance_index = balance_index  
       self.smooth = smooth  
       self.size_average = size_average  
  
       if self.smooth is not None:  
           if self.smooth < 0 or self.smooth > 1.0:  
               raise ValueError('smooth value should be in [0,1]')  
  
   def forward(self, logit, target):  
       if self.apply_nonlin is not None:  
           logit = self.apply_nonlin(logit)  
       num_class = logit.shape[1]  
  
       if logit.dim() > 2:  
           # N,C,d1,d2 -> N,C,m (m=d1*d2*...)  
           logit = logit.view(logit.size(0), logit.size(1), -1)  
           logit = logit.permute(0, 2, 1).contiguous()  
           logit = logit.view(-1, logit.size(-1))  
       target = torch.squeeze(target, 1)  
       target = target.view(-1, 1)  
       # print(logit.shape, target.shape)  
       #   
       alpha = self.alpha  
  
       if alpha is None:  
           alpha = torch.ones(num_class, 1)  
       elif isinstance(alpha, (list, np.ndarray)):  
           assert len(alpha) == num_class  
           alpha = torch.FloatTensor(alpha).view(num_class, 1)  
           alpha = alpha / alpha.sum()  
       elif isinstance(alpha, float):  
           alpha = torch.ones(num_class, 1)  
           alpha = alpha * (1 - self.alpha)  
           alpha[self.balance_index] = self.alpha  
  
       else:  
           raise TypeError('Not support alpha type')  
         
       if alpha.device != logit.device:  
           alpha = alpha.to(logit.device)  
  
       idx = target.cpu().long()  
  
       one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()  
       one_hot_key = one_hot_key.scatter_(1, idx, 1)  
       if one_hot_key.device != logit.device:  
           one_hot_key = one_hot_key.to(logit.device)  
  
       if self.smooth:  
           one_hot_key = torch.clamp(  
               one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)  
       pt = (one_hot_key * logit).sum(1) + self.smooth  
       logpt = pt.log()  
  
       gamma = self.gamma  
  
       alpha = alpha[idx]  
       alpha = torch.squeeze(alpha)  
       loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt  
  
       if self.size_average:  
           loss = loss.mean()  
       else:  
           loss = loss.sum()  
       return loss

4、dice soft loss

语义分割任务中常用的还有一个基于 Dice 系数的损失函数,该系数实质上是两个样本之间重叠的度量。此度量范围为 0~1,其中 Dice 系数为1表示完全重叠。Dice 系数最初是用于二进制数据的,可以计算为:

|A∩B| 代表集合A和B之间的公共元素,并且 |A| 代表集合A中的元素数量(对于集合B同理)。

对于在预测的分割掩码上评估 Dice 系数,我们可以将 |A∩B| 近似为预测掩码和标签掩码之间的逐元素乘法,然后对结果矩阵求和。

计算 Dice 系数的分子中有一个2,那是因为分母中对两个集合的元素个数求和,两个集合的共同元素被加了两次。 为了设计一个可以最小化的损失函数,可以简单地使用 1−Dice。 这种损失函数被称为 soft Dice loss,这是因为我们直接使用预测出的概率,而不是使用阈值将其转换成一个二进制掩码。

Dice loss是针对前景比例太小的问题提出的,dice系数源于二分类,本质上是衡量两个样本的重叠部分。

对于神经网络的输出,分子与我们的预测和标签之间的共同激活有关,而分母分别与每个掩码中的激活数量有关,这具有根据标签掩码的尺寸对损失进行归一化的效果。

对于每个类别的mask,都计算一个 Dice 损失:

将每个类的 Dice 损失求和取平均,得到最后的 Dice soft loss。

下面是代码实现:

def soft_dice_loss(y_true, y_pred, epsilon=1e-6): 
    ''' 
    Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
    Assumes the `channels_last` format.
  
    # Arguments
        y_true: b x X x Y( x Z...) x c One hot encoding of ground truth
        y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax) 
        epsilon: Used for numerical stability to avoid divide by zero errors
    
    # References
        V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation 
        https://arxiv.org/abs/1606.04797
        More details on Dice loss formulation 
        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)
        
        Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
    '''
    
    # skip the batch and class axis for calculating Dice score
    axes = tuple(range(1, len(y_pred.shape)-1)) 
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
    
    return 1 - np.mean(numerator / (denominator + epsilon)) # average over classes and batch

5、soft IoU loss

前面我们知道计算 Dice 系数的公式,其实也可以表示为:

其中 TP 为真阳性样本,FP 为假阳性样本,FN 为假阴性样本。分子和分母中的 TP 样本都加了两次。

IoU 的计算公式和这个很像,区别就是 TP 只计算一次:

和 Dice soft loss 一样,通过 IoU 计算损失也是使用预测的概率值:

其中 C 表示总的类别数。

6、Tversky Loss

论文地址为:https://arxiv.org/pdf/1706.05… 

医学影像中存在很多的数据不平衡现象,使用不平衡数据进行训练会导致严重偏向高精度但低召回率(sensitivity)的预测,这是不希望的,特别是在医学应用中,假阴性比假阳性更难容忍。本文提出了一种基于Tversky指数的广义损失函数,解决了三维全卷积深神经网络训练中数据不平衡的问题,在精度和召回率之间取得了较好的折衷。

Dice loss的正则化版本,以控制假阳性和假阴性对损失函数的贡献,TL被定义为

class TverskyLoss(nn.Module):  
   def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,  
                square=False):  
       """  
       paper: https://arxiv.org/pdf/1706.05721.pdf  
       """  
       super(TverskyLoss, self).__init__()  
  
       self.square = square  
       self.do_bg = do_bg  
       self.batch_dice = batch_dice  
       self.apply_nonlin = apply_nonlin  
       self.smooth = smooth  
       self.alpha = 0.3  
       self.beta = 0.7  
  
   def forward(self, x, y, loss_mask=None):  
       shp_x = x.shape  
  
       if self.batch_dice:  
           axes = [0] + list(range(2, len(shp_x)))  
       else:  
           axes = list(range(2, len(shp_x)))  
  
       if self.apply_nonlin is not None:  
           x = self.apply_nonlin(x)  
  
       tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)  
  
  
       tversky = (tp + self.smooth) / (tp + self.alpha*fp + self.beta*fn + self.smooth)  
  
       if not self.do_bg:  
           if self.batch_dice:  
               tversky = tversky[1:]  
           else:  
               tversky = tversky[:, 1:]  
       tversky = tversky.mean()  
  
       return -tversky  

7、Generalized Dice Loss

Dice loss虽然一定程度上解决了分类失衡的问题,但却不利于严重的分类不平衡。例如小目标存在一些像素的预测误差,这很容易导致Dice的值发生很大的变化。Sudre等人提出了Generalized Dice Loss (GDL)

GDL优于Dice损失,因为不同的区域对损失有相似的贡献,并且GDL在训练过程中更稳定和鲁棒。

8、Boundary Loss

为了解决类别不平衡的问题,Kervadec等人[95]提出了一种新的用于脑损伤分割的边界损失。该损失函数旨在最小化分割边界和标记边界之间的距离。作者在两个没有标签的不平衡数据集上进行了实验。结果表明,Dice los和Boundary los的组合优于单一组合。复合损失的定义为

其中第一部分是一个标准的Dice los,它被定义为

第二部分是Boundary los,它被定义为

9、Exponential Logarithmic Loss


在(9)中,加权Dice los实际上是得到的Dice值除以每个标签的和,对不同尺度的对象达到平衡。因此,Wong等人结合focal loss [96] 和dice loss,提出了用于脑分割的指数对数损失(EXP损失),以解决严重的类不平衡问题。通过引入指数形式,可以进一步控制损失函数的非线性,以提高分割精度。EXP损失函数的定义为

其中,两个新的参数权重分别用ωdice和ωcross表示。Ldice是指数对数骰子损失,而交叉损失是交叉熵损失

其中x是像素位置,i是标签,l是位置x处的地面真值。pi(x)是从softmax输出的概率值。
在(17)中,fk是标签k出现的频率,该参数可以减少更频繁出现的标签的影响。γDice和γcross都用于增强损失函数的非线性。

10.Focal Tversky Loss

与“Focal loss”相似,后者着重于通过降低易用/常见损失的权重来说明困难的例子。Focal Tversky Loss还尝试借助γ系数来学习诸如在ROI(感兴趣区域)较小的情况下的困难示例,如下所示:

class FocalTversky_loss(nn.Module):  
   """  
   paper: https://arxiv.org/pdf/1810.07842.pdf  
   author code: https://github.com/nabsabraham/focal-tversky-unet/blob/347d39117c24540400dfe80d106d2fb06d2b99e1/losses.py#L65  
   """  
   def __init__(self, tversky_kwargs, gamma=0.75):  
       super(FocalTversky_loss, self).__init__()  
       self.gamma = gamma  
       self.tversky = TverskyLoss(**tversky_kwargs)  
  
   def forward(self, net_output, target):  
       tversky_loss = 1 + self.tversky(net_output, target) # = 1-tversky(net_output, target)  
       focal_tversky = torch.pow(tversky_loss, self.gamma)  
       return focal_tversky  

11、Sensitivity Specificity Loss

首先敏感性就是召回率,检测出确实有病的能力:

640-8.png

特异性,检测出确实没病的能力:

640-9.png

而Sensitivity Specificity Loss为:

640-10.png
image.png
class SSLoss(nn.Module):  
   def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,  
                square=False):  
       """  
       Sensitivity-Specifity loss  
       paper: http://www.rogertam.ca/Brosch_MICCAI_2015.pdf  
       tf code: https://github.com/NifTK/NiftyNet/blob/df0f86733357fdc92bbc191c8fec0dcf49aa5499/niftynet/layer/loss_segmentation.py#L392  
       """  
       super(SSLoss, self).__init__()  
  
       self.square = square  
       self.do_bg = do_bg  
       self.batch_dice = batch_dice  
       self.apply_nonlin = apply_nonlin  
       self.smooth = smooth  
       self.r = 0.1 # weight parameter in SS paper  
  
   def forward(self, net_output, gt, loss_mask=None):  
       shp_x = net_output.shape  
       shp_y = gt.shape  
       # class_num = shp_x[1]  
         
       with torch.no_grad():  
           if len(shp_x) != len(shp_y):  
               gt = gt.view((shp_y[0], 1, *shp_y[1:]))  
  
           if all([i == j for i, j in zip(net_output.shape, gt.shape)]):  
               # if this is the case then gt is probably already a one hot encoding  
               y_onehot = gt  
           else:  
               gt = gt.long()  
               y_onehot = torch.zeros(shp_x)  
               if net_output.device.type == "cuda":  
                   y_onehot = y_onehot.cuda(net_output.device.index)  
               y_onehot.scatter_(1, gt, 1)  
  
       if self.batch_dice:  
           axes = [0] + list(range(2, len(shp_x)))  
       else:  
           axes = list(range(2, len(shp_x)))  
  
       if self.apply_nonlin is not None:  
           softmax_output = self.apply_nonlin(net_output)  
         
       # no object value  
       bg_onehot = 1 - y_onehot  
       squared_error = (y_onehot - softmax_output)**2  
       specificity_part = sum_tensor(squared_error*y_onehot, axes)/(sum_tensor(y_onehot, axes)+self.smooth)  
       sensitivity_part = sum_tensor(squared_error*bg_onehot, axes)/(sum_tensor(bg_onehot, axes)+self.smooth)  
  
       ss = self.r * specificity_part + (1-self.r) * sensitivity_part  
  
       if not self.do_bg:  
           if self.batch_dice:  
               ss = ss[1:]  
           else:  
               ss = ss[:, 1:]  
       ss = ss.mean()  
  
       return ss

12、Log-Cosh Dice Loss

Dice系数是一种用于评估分割输出的度量标准。它也已修改为损失函数,因为它可以实现分割目标的数学表示。但是由于其非凸性,它多次都无法获得最佳结果。Lovsz-softmax损失旨在通过添加使用Lovsz扩展的平滑来解决非凸损失函数的问题。同时,Log-Cosh方法已广泛用于基于回归的问题中,以平滑曲线。

640.png
640-1.png

将Cosh(x)函数和Log(x)函数合并,可以得到Log-Cosh Dice Loss:

640-2.png
def log_cosh_dice_loss(self, y_true, y_pred):  
       x = self.dice_loss(y_true, y_pred)  
       return tf.math.log((torch.exp(x) + torch.exp(-x)) / 2.0)  

13、Hausdorff Distance Loss

Hausdorff Distance Loss(HD)是分割方法用来跟踪模型性能的度量。它定义为:

640-4.png

任何分割模型的目的都是为了最大化Hausdorff距离,但是由于其非凸性,因此并未广泛用作损失函数。有研究者提出了基于Hausdorff距离的损失函数的3个变量,它们都结合了度量用例,并确保损失函数易于处理。

class HDDTBinaryLoss(nn.Module):  
   def __init__(self):  
       """  
       compute haudorff loss for binary segmentation  
       https://arxiv.org/pdf/1904.10030v1.pdf          
       """  
       super(HDDTBinaryLoss, self).__init__()  
  
  
   def forward(self, net_output, target):  
       """  
       net_output: (batch_size, 2, x,y,z)  
       target: ground truth, shape: (batch_size, 1, x,y,z)  
       """  
       net_output = softmax_helper(net_output)  
       pc = net_output[:, 1, ...].type(torch.float32)  
       gt = target[:,0, ...].type(torch.float32)  
       with torch.no_grad():  
           pc_dist = compute_edts_forhdloss(pc.cpu().numpy()>0.5)  
           gt_dist = compute_edts_forhdloss(gt.cpu().numpy()>0.5)  
       # print('pc_dist.shape: ', pc_dist.shape)  
         
       pred_error = (gt - pc)**2  
       dist = pc_dist**2 + gt_dist**2 # \alpha=2 in eq(8)  
  
       dist = torch.from_numpy(dist)  
       if dist.device != pred_error.device:  
           dist = dist.to(pred_error.device).type(torch.float32)  
  
       multipled = torch.einsum("bxyz,bxyz->bxyz", pred_error, dist)  
       hd_loss = multipled.mean()  
  
       return hd_loss

总结:

交叉熵损失把每个像素都当作一个独立样本进行预测,而 dice loss 和 iou loss 则以一种更“整体”的方式来看待最终的预测输出。

这两类损失是针对不同情况,各有优点和缺点,在实际应用中,可以同时使用这两类损失来进行互补。

Deeplab v3

论文地址: https://arxiv.org/abs/1706.05587

Deeplab v3是v2版本的进一步升级,作者们在对空洞卷积重新思考的基础上,进一步对Deeplab系列的基本框架进行了优化,去掉了v1和v2版本中一直坚持的CRF后处理模块,升级了主干网络和ASPP模块,使得网络能够更好地处理语义分割中的多尺度问题。提出Deeplab v3的论文为Rethinking Atrous Convolution for Semantic Image Segmentation,是Deeplab系列后期网络的代表模型。

DeepLab V3的改进主要包括以下几方面:

1)提出了更通用的框架,适用于任何网络

2)复制了ResNet最后的block,并级联起来

3)在ASPP中使用BN层

4)去掉了CRF

随着语义分割的发展,逐渐有两大问题亟待解决:一个是连续的池化和卷积步长导致的下采样图像信息丢失问题,这个问题已经通过空洞卷积扩大感受野得到比较好的处理;另外一个则是多尺度和利用上下文信息问题。论文中分别回顾了四种基于多尺度和上下文信息进行语义分割的方法,如图1所示,包括图像金字塔、编解码架构、深度空洞卷积网络以及空间金字塔池化,这四种方法各有优缺点,ASPP可以算是深度空洞卷积和空间金字塔池化的一种结合,Deeplab v3在v2的ASPP基础上,进一步探索了空洞卷积在多尺度和上下文信息中的作用。

Deeplab v3可作为通用框架融入到任意网络结构中,具体地,以串行方式设计空洞卷积模块,复制ResNet的最后一个卷积块,并将复制后的卷积块以串行方式进行级联,如图2所示。DeepLab V3将空洞卷积应用在级联模块。具体来说,我们取ResNet中最后一个block,在下图中为block4,并在其后面增加级联模块。

在卷积块串行级联基础上,Deeplab v3又对ASPP模块进行并行级联,v3对ASPP模块进行了升级,相较于v2版本加入了批归一化(Batch Normalization,BN),通过组织不同的空洞扩张率的卷积块,同时加入图像级特征,能够更好地捕捉多尺度上下文信息,并且也能够更容易训练,如图3所示。

1)ASPP中应用了BN层

2)随着采样率的增加,滤波器中有效的权重减少了(有效权重减少,难以捕获原距离信息,这要求合理控制采样率的设置)

3)使用模型最后的特征映射的全局平均池化(为了克服远距离下有效权重减少的问题)

总结来看,Deeplab v3更充分的利用了空洞卷积来获取大范围的图像上下文信息。具体包括:多尺度信息编码、带有逐步翻倍的空洞扩张率的级联模块以及带有图像级特征的ASPP模块。实验结果表明,该模型在 PASCAL VOC数据集上相较于v2版本有了显着进步,取得了当时SOTA精度水平。

Deeplab v3的PyTorch实现可参考:

https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/deeplabv3.py

代码实现:

class DeeplabV3(ResNet):

    def __init__(self, n_class, block, layers, pyramids, grids, output_stride=16):
        self.inplanes = 64
        super(DeeplabV3, self).__init__()
        if output_stride == 16:
            strides = [1, 2, 2, 1]
            rates = [1, 1, 1, 2]
        elif output_stride == 8:
            strides = [1, 2, 1, 1]
            rates = [1, 1, 2, 2]
        else:
            raise NotImplementedError

        # Backbone Modules
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)   # h/4, w/4

        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], rate=rates[0]) # h/4, w/4
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], rate=rates[1]) # h/8, w/8
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], rate=rates[2]) # h/16,w/16
        self.layer4 = self._make_MG_unit(block, 512, blocks=grids, stride=strides[3], rate=rates[3])# h/16,w/16

        # Deeplab Modules
        self.aspp1 = ASPP_module(2048, 256, rate=pyramids[0])  
        self.aspp2 = ASPP_module(2048, 256, rate=pyramids[1])
        self.aspp3 = ASPP_module(2048, 256, rate=pyramids[2])
        self.aspp4 = ASPP_module(2048, 256, rate=pyramids[3])

        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(2048, 256, kernel_size=1, stride=1, bias=False),
                                             nn.BatchNorm2d(256),
                                             nn.ReLU())

        # get result features from the concat
        self._conv1 = nn.Sequential(nn.Conv2d(1280, 256, kernel_size=1, stride=1, bias=False),
                                    nn.BatchNorm2d(256),
                                    nn.ReLU())

        # generate the final logits
        self._conv2 = nn.Conv2d(256, n_class, kernel_size=1, bias=False)

        self.init_weight()

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)

        # image-level features
        x5 = self.global_avg_pool(x)
        x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self._conv1(x)
        x = self._conv2(x)

        x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)

        return x

其中重要的_makelayer, _make_MG_unit和ASSP模块实现如下:

def _make_layer(self, block, planes, blocks, stride=1, rate=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, rate, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks=[1, 2, 4], stride=1, rate=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, rate=blocks[0] * rate, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1, rate=blocks[i] * rate))

        return nn.Sequential(*layers)


class ASPP_module(nn.Module):
    def __init__(self, inplanes, planes, rate):
        super(ASPP_module, self).__init__()
        if rate == 1:
            kernel_size = 1
            padding = 0
        else:
            kernel_size = 3
            padding = rate
        self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=rate, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

训练策略:

Crop size:

  • 为了大采样率的空洞卷积能够有效,需要较大的图片大小;否则,大采样率的空洞卷积权值就会主要用于padding区域。
  • 在Pascal VOC 2012数据集的训练和测试中我们采用了513的裁剪尺寸。

Batch normalization:

  • 我们在ResNet之上添加的模块都包括BN层
  • 当output_stride=16时,采用batchsize=16,同时BN层的参数做参数衰减0.9997。
  • 在增强的数据集上,以初始学习率0.007训练30K后,冻结BN层参数,然后采用output_stride=8,再使用初始学习率0.001在PASCAL官方的数据集上训练30K。
  • 训练output_stride=16比output_stride=8要快很多,因为其中间的特征映射在空间上小四倍。但output_stride=16在特征映射上相对粗糙,快是因为牺牲了精度。

Upsampling logits:

  • 在先前的工作上,我们是将output_stride=8的输出与Ground Truth下采样8倍做比较。
  • 现在我们发现保持Ground Truth更重要,故我们是将最终的输出上采样8倍与完整的Ground Truth比较。

Data augmentation:

在训练阶段,随机缩放输入图像(从0.5到2.0)和随机左-右翻转

Deeplab v3+

论文地址 https://arxiv.org/abs/1706.05587

Deeplab v3+是Deeplab系列最后一个网络结构,也是基于空洞卷积和多尺度系列模型的集大成者。相较于Deeplab v3,v3+版本参考了UNet系列网络,对基于空洞卷积的Deeplab网络引入了编解码结构,一定程度上来讲,Deeplab v3+是编解码和多尺度这两大系列网络的一个大融合,在很长一段时间内代表了自然图像语义分割的SOTA水平的分割模型。提出Deeplab v3+的论文为Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation,至今仍然是最常用的一个语义分割网络模型。

对于语义分割问题,尽管各种网络模型很多,但Deeplab v3+的作者们认为迄今为止仅有两大主流设计:一个是以UNet为代表的编解码结构,另一个就是以Deeplab为代表的ASPP和多尺度设计,前者以获取图像中目标对象细致的图像边界见长,后者则更擅长捕捉图像中丰富的上下文多尺度信息。鉴于此,Deeplab v3+将编解码结构引入到v3网络中,Deeplab v3直接作为编码器,然后再加入解码器设计,这样就构成了一个SPP(空间金字塔池化)+Atrous Conv(空洞卷积)+Encoder-Decoder(编解码)包含众多丰富元素的组合结构。这三种结构如下图所示。

完整的Deeplab v3+网络结构如下图所示。可以看到,输入图像先进入有深度空洞卷积构成的编码器部分中,由并行的ASPP模块构建并且融入图像级的池化特征,然后再进行合并并通过1*1卷积降低通道数后得到编码器的输出,编码器部分没有做特别的改动,直接使用了Deeplab v3的结构作为编码器,旨在提取图像的多尺度上下文信息。解码器部分则是做了一些特别的设计,编码器输出先经过4倍双线性插值上采样,同时也从编码器部分链接浅层的图像特征到解码器部分,并经1*1卷积降维后与4倍上采样的输出进行拼接,最后经过一个3*3卷积后再经过一次4倍上采样后得到最终的分割输出结果。

另外,Deeplab v3+除了延用之前的ResNet系列作为backbone之外,也尝试了以轻量级着称的Xception网络,具体做法就是将深度可分离卷积(Depth separable convolution)引入到空洞卷积中,能够极大的减少计算量,虽然有一定的精度损失,但在追求速度性能上不失为一种非常好的选择。下面简单介绍一下深度可分离卷积。

从维度的角度看,卷积核可以看成是一个空间维(宽和高)和通道维的组合,而卷积操作则可以视为空间相关性和通道相关性的联合映射。从Inception的1*1卷积来看,卷积中的空间相关性和通道相关性是可以解耦的,将它们分开进行映射,可能会达到更好的效果。

深度可分离卷积是在1*1卷积基础上的一种创新。主要包括两个部分:深度卷积和1*1卷积。深度卷积的目的在于对输入的每一个通道都单独使用一个卷积核对其进行卷积,也就是通道分离后再组合。1*1卷积的目的则在于加强深度。下面以一个例子来看一下深度可分离卷积:假设我们用128个3*3*3的滤波器对一个7*7*3的输入进行卷积,可得到5*5*128的输出,其计算量为5*5*128*3*3*3=86400,如下图所示。

然后我们看如何使用深度可分离卷积来实现同样的结果。深度可分离卷积的第一步是深度卷积。这里的深度卷积,就是分别用3个3*3*1的滤波器对输入的3个通道分别做卷积,也就是说要做3次卷积,每次卷积都有一个5*5*1的输出,组合在一起便是5*5*3的输出。现在为了拓展深度达到128,我们需要执行深度可分离卷积的第二步:1*1卷积。现在我们用128个1*1*3的滤波器对5*5*3进行卷积,就可以得到5*5*128的输出。完整过程如下图所示。

看一下深度可分离卷积的计算量如何。第一步深度卷积的计算量:5*5*1*3*3*1*3=675。第二步11卷积的计算量:5*5*128*1*1*3=9600,合计计算量为10275次。可见,相同的卷积计算输出,深度可分离卷积要比常规卷积节省12倍的计算成本。典型的应用深度可分离卷积的网络模型包括Xception和MobileNet等。本质上而言,Xception就是应用了深度可分离卷积的Inception网络。

Deeplab v3+在PASCAL VOC和Cityscapes等公开数据集上均取得了SOTA的结果,即使在深度学习语义分割发展日新月异发展的今天,Deeplab v3+仍然不失为一个非常好的语义分割解决方案。

关于Deeplab系列各版本的技术要点总结如下表所示。

Deeplab v2

DeepLabv2:
DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs

TPAMI 2018

Deeplab v2 严格上算是Deeplab v1版本的一次不大的更新,在v1的空洞卷积和CRF基础上,重点关注了网络对于多尺度问题的适用性。多尺度问题一直是目标检测和语义分割任务的重要挑战之一,以往实现多尺度的惯常做法是对同一张图片进行不同尺寸的缩放后获取对应的卷积特征图,然后将不同尺寸的特征图分别上采样后再融合来获取多尺度信息,但这种做法最大的缺点就是计算开销太大。Deeplab v2借鉴了空间金字塔池化(Spatial Pyramaid Pooling, SPP)的思路,提出了基于空洞卷积的空间金字塔池化(Atrous Spatial Pyramaid Pooling, ASPP),这也是Deeplab v2最大的亮点。提出Deeplab v2的论文为DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs,是Deeplab系列网络中前期结构的重要代表。

ASPP来源于R-CNN(Regional CNN)目标检测领域中SPP结构,该方法表明任意尺度的图像区域可以通过对单一尺度提取的卷积特征进行重采样而准确有效地分类。ASPP在其基础上将普通卷积改为空洞卷积,通过使用多个不同扩张率且并行的空洞卷积进行特征提取,最后在对每个分支进行融合。ASPP结构如下图所示。

除了ASPP之外,Deeplab v2还将v1中VGG-16的主干网络换成了ResNet-101,算是对编码器的一次升级,使其具备更强的特征提取能力。Deeplab v2在PASCAL VOC和Cityscapes等语义分割数据集上均取得了当时的SOTA结果。关于ASPP模块的一个简单实现参考如下代码所示,先是分别定义了ASPP的卷积和池化方法,然后在其基础上定义了ASPP模块。


### 定义ASPP卷积方法
class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation,
                      dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        super(ASPPConv, self).__init__(*modules)

  ### 定义ASPP池化方法
class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        size = x.shape[-2:]
        x = super(ASPPPooling, self).forward(x)
        return F.interpolate(x, size=size, mode='bilinear',
 align_corners=False)

### 定义ASPP模块
class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)))

        rate1, rate2, rate3 = tuple(atrous_rates)
        modules.append(ASPPConv(in_channels, out_channels, rate1))
        modules.append(ASPPConv(in_channels, out_channels, rate2))
        modules.append(ASPPConv(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),)

  # ASPP前向计算流程
    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)

下图是Deeplab v2在Cityscapes数据集上的分割效果:

Deeplab v1

DeepLabv1:
Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs

ICLR 2015

在语义分割发展早期,一些研究观点认为将CNN用于图像分割主要存在两个问题:一个是下采样导致的信息丢失问题,另一个则是CNN的空间不变性问题,这与CNN本身的特性有关,这种空间不变性有利于图像分类但却不利于图像分割中的像素定位。从多尺度和上下文信息的角度来看,这两个问题是导致FCN分割效果有限的重要原因。因而,相关研究针对上述两个问题提出了Deeplab v1网络,通过在常规卷积中引入空洞(Atrous)和对CNN分割结果补充CRF作为后处理来优化分割效果。提出Deeplab v1的论文为Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs,是Deeplab系列的开篇之作。

针对第一个问题,池化下采样操作引起信息丢失,Deeplab v1给出的解决方案算是另辟蹊径。常规卷积中,使用池化下采样的主要目的是增大每个像素的感受野,但在Deeplab v1中,作者们的想法是可以不用池化也可以增大像素的感受野,尝试在卷积操作本身上重新进行设计。在Deeplab v1,一种在常规卷积核中插入空洞的设计被提出,相较于池化下采样,空洞卷积能够在不降低图像分辨率的情况下扩大像素感受野,从而就避免了信息损失的问题

空洞卷积(Dilated/Atrous Convolution)也叫扩张卷积或者膨胀卷积,字面意思上来说就是在卷积核中插入空洞,起到扩大感受野的作用。空洞卷积的直接做法是在常规卷积核中填充0,用来扩大感受野,且进行计算时,空洞卷积中实际只有非零的元素起了作用。假设以一个变量a来衡量空洞卷积的扩张系数,则加入空洞之后的实际卷积核尺寸与原始卷积核尺寸之间的关系:

K=k+(k-1)(a-1)

其中为k原始卷积核大小,a为空洞率(Dilation Rate),K为经过扩展后实际卷积核大小。除此之外,空洞卷积的卷积方式跟常规卷积一样。当a=1时,空洞卷积就退化为常规卷积。a=1,2,4时,空洞卷积示意图如下图所示。

对于语义分割而言,空洞卷积主要有三个作用:

第一是扩大感受野,具体前面已经说的比较多了,这里不做重复。但需要明确一点,池化也可以扩大感受野,但空间分辨率降低了,相比之下,空洞卷积可以在扩大感受野的同时不丢失分辨率,且保持像素的相对空间位置不变。简单而言就是空洞卷积可以同时控制感受野和分辨率。

第二就是获取多尺度上下文信息。当多个带有不同空洞率的空洞卷积核叠加时,不同的感受野会带来多尺度信息,这对于分割任务是非常重要的。

第三就是可以降低计算量,不需要引入额外的参数,如图4-13所示,实际卷积时只有带有红点的元素真正进行计算。

针对第二个问题,Deeplab v1通过引入全连接的CRF来对CNN的粗分割结果进行优化。CRF作为一种经典的概率图模型,可用于图像像素之间的关系描述,在传统图像处理中主要用于图像平滑处理。但对于CNN分割问题来说,使用短程的CRFs可能会于事无补,因为分割问题的目标是恢复图像的局部细节信息,而不是对图像做平滑处理。所以Deeplab v1提出的解决方案叫做全连接CRF(Fully Connected CRF)。

条件随机场可以优化物体的边界,平滑带噪声的分割结果,去掉物体中间的预测的孔洞,使得分割结果更加准确。

CRF是一种经典的概率图模型,简单而言就是给定一组输入序列的条件下,求另一组输出序列的条件概率分布模型,CRF在自然语言处理领域有着广泛应用。CRF在语义分割后处理中用法的基本思路如下:对于FCN或者其他分割网络的粗粒度分割结果而言,每个像素点i具有对应的类别标签x_i和观测值y_i,以每个像素为节点,以像素与像素之间的关系作为边即可构建一个CRF模型。在这个CRF模型中,我们通过观测变量y_i来预测像素i对应的标签值x_i。CRF用于像素预测的结构如下图所示。

全连接CRF使用的能量函数为:

Deeplab v1模型流程如下图所示。输入图像经过深度卷积网络(DCNN)后生成浓缩的、粗粒度的语义特征图,再经过双线性插值上采样后形成粗分割结果,最后经全连接的CRF后处理生成最终的分割结果图。

下图是Deeplab v1在无CRF和有CRF后处理的分割效果图:

总的来看,Deeplab v1有如下几个有点:

(1)速度快。基于空洞卷积的CNN分割网络,能够保证8秒每帧(Frame Per Second,FPS)的推理速度,后处理的全连接CRF也仅需要0.5秒。

(2)精度高。Deeplab v1在当时取得了在PASCAL VOC数据集上SOTA的分割模型表现,准确率超过此前SOTA的7.2%。

(3)简易性。DCNN和CRF均为成熟的算法模块,将其进行简单的级联即可在当前的语义分割模型上取得好的效果。

Deeplab v1 PyTorch参考实现代码如下:

https://github.com/wangleihitcs/DeepLab-V1-PyTorch