小红书推荐系统 —公开课

推荐系统课件:https://github.com/wangshusen/RecommenderSystem

工业界的推荐系统

这门课程结合小红书的业务场景和内部实践,讲解主流的工业界推荐系统技术。

  1. 概要
  2. 召回
  3. 排序
  4. 交叉结构
  5. 用户行为序列建模
    • 用户行为序列特征
    • DIN 模型
    • SIM 模型
  6. 重排
    • 多样性 MMR
    • 多样性 DPP
    • 多样性 MGS
  7. 物品冷启动

Real-ESRGAN 超分辨网络

论文:Real-ESRGAN: TrainingReal-World Blind Super-Resolution with Pure Synthetic Data

代码:https://github.com/xinntao/Real-ESRGAN

Real-ESRGAN 的目标是开发出实用的图像/视频修复算法。
在 ESRGAN 的基础上使用纯合成的数据来进行训练,以使其能被应用于实际的图片修复的场景(顾名思义:Real-ESRGAN)。

  1. 目标:解决真实场景下的图像模糊问题。
  2. 数据集的构建:模糊核、噪声、尺寸缩小、压缩四种操作的随机顺序。
  3. 超分网络backbone:ESRGAN的生成网络+U-Net discriminator判别器。
  4. 损失函数:L1 loss,perceptual loss,生成对抗损失。
  5. 主要对比方法是:RealSR、ESRGAN、BSRGAN、DAN、CDC。

创新点

  1. 提出了新的构建数据集的方法,用高阶处理,增强降阶图像的复杂度。
  2. 构造数据集时引入sinc filter,解决了图像中的振铃和过冲现象。
  3. 替换原始ESRGAN中的VGG-discriminator,使用U-Net discriminator,以增强图像的对细节上的对抗学习。
  4. 引入spectral normalization以稳定由于复杂数据集和U-Net discriminator带来的训练不稳定情况。

数据集构建

在讨论数据集的构建前,作者详细讨论了造成图像模糊的原因,例如:年代久远的手机、传感器噪声、相机模糊、图像编辑、图像在网络中的传输、JPEG压缩以及其它噪声。原文如下:

For example, when we take a photo with our cellphones, the photos may have several degradations, such as camera blur, sensor noise, sharpening artifacts, and JPEG compression. We then do some editing and upload to a social media APP, which introduces further compression and unpredictable noises.

所以作者针对以上问题,提出了High-order降级模型。先面我们先介绍first-order降级模型,然后就很好理解High-order降级模型了。

First-order

First-order降级模型其实就是常规的降级模型,如上式所示,按顺序执行上述操作。

x代表降级后的图像,D代表降级函数,y代表原始图像;
k代表模糊核,r代表缩小比例,n代表加入的噪声,JPEG代表进行压缩。

每一种降级方法又有多种降级方案可以选择,如下图所示:

对于模糊核k,本方法使用各项同性(isotropic)和各向异性(anisotropic)的高斯模糊核。关于sinc filter会在下文中提到。

对于缩小操作r,常用的方法又双三次插值、双线性插值、区域插值—由于最近邻插值需要考虑对齐问题,所以不予以考虑。在执行缩小操作时,本方法从提到的3种插值方式中随机选择一种。

对于加入噪声操作n,本方法同时加入高斯噪声和服从泊松分布的噪声。同时,根据待超分图像的通道数,加入噪声的操作可以分为对彩色图像添加噪声和对灰度图像添加噪声。

JPEG压缩,本方法通过从[0, 100]范围中选择压缩质量,对图像进行JPEG压缩,其中0表示压缩后的质量最差,100表示压缩后的质量最好。JPEG压缩方法点此处

  • High-order

First-order由于使用相对单调的降级方法,其实很难模仿真实世界中的图像低分辨模糊情况。因此,作者提出的High-order其实是为了使用更复杂的降级方法,更好的模拟真实世界中的低分辨模糊情况,从而达到更好的学习效果。一阶降级模型构建的数据集训练结果如下:

高阶降级模型公式如下:

上式,其实就是对First-order进行多次重复操作,也就是每一个D都是执行一次完整的First-order降级。作者通过实验得出,当执行2次First-order时生成的数据集训练效果最好。所以,High-order的pipeline如下:

  • sinc filter

为了解决超分图像的振铃和过冲现象(振铃过冲在图像处理中很常见,此处不过多介绍),作者提出了在构建降级模型中增加sinc filter的操作。先来看一下振铃和过冲伪影的效果:

上图表示实际中的振铃和过冲伪影现象,下图表示通过对sinc filter设置不同的因子人工模仿的振铃和过冲伪影现象。过于如何构造sinc filter,详细细节建议看原文。

sinc filter在两个位置进行设置,一是在每一阶的模糊核k的处理中,也就是在各项同性和各项异性的高斯模糊之后,设置sinc filter;二是在最后一阶的JPEG压缩时,设置sinc filter,其中最后一阶的JPEG和sinc filter执行先后顺序是随机的。

网络结构

  • 生成网络

生成网络是ESRGAN的生成网络,基本没变,只是在功能上增加了对x2和x1倍的图像清晰度提升。对于x4倍的超分辨,网络完全按照ESRGAN的生成器执行;而对于X2和X1倍的超分辨,网络先进行pixel-unshuffle(pixel-shuffl的反操作,pixel-shuffle可理解为通过压缩图像通道而对图像尺寸进行放大),以降低图像分辨率为前提,对图像通道数进行扩充,然后将处理后的图像输入网络进行超分辨重建。举个例子:对于一幅图像,若只想进行x2倍放大变清晰,需先通过pixel-unshuffle进行2倍缩小,然后通过网络进行4倍放大。

对抗网络

由于使用的复杂的构建数据集的方式,所以需要使用更先进的判别器对生成图像进行判别。之前的ESRGAN的判别器更多的集中在图像的整体角度判别真伪,而使用U-Net 判别器可以在像素角度,对单个生成的像素进行真假判断,这能够在保证生成图像整体真实的情况下,注重生成图像细节。

  • 光谱标准正则化

通过加入这一操作,可以缓和由于复杂数据集合复杂网络带来的训练不稳定问题。

训练

分为两步:

  1. 先通过L1 loss,训练以PSRN为导向的网络,获得的模型称为Real-ESRNet。
  2. 以Real-ESRNet的网络参数进行网络初始化,损失函数设置为 L1 loss、perceptual loss、 GAN loss,训练最终的网络Real-ESRGAN。

ESRGAN图像超分辨

论文:https://arxiv.org/abs/1809.00219

github: https://github.com/xinntao/ESRGAN

ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks发表于ECCV 2018 的 Workshops,在SRGAN的基础上进行了改进,包括改进网络的结构,判决器的判决形式,以及更换了一个用于计算感知域损失的预训练网络

超分辨率生成对抗网络(SRGAN)是一项开创性的工作,能够在单一图像超分辨率中生成逼真的纹理。这项工作发表于CVPR 2017,

文章链接:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

但是,放大后的细节通常伴随着令人不快的伪影。为了更进一步地提升视觉质量,作者仔细研究了SRGAN的三个关键部分:1.网络结构 2.对抗性损失 3.感知域损失;并对每一项进行改进,得到ESRGAN。具体而言,文章提出了一种Residual-in-Residual Dense Block (RRDB)的网络单元,在这个单元中,去掉了BN(Batch Norm)层。此外,作者借鉴了relativistic GAN的想法,让判别器预测图像的真实性而不是图像“是否是fake图像”。最后,文章对感知域损失进行改进,使用激活前的特征,这样可以为亮度一致性和纹理恢复提供更强的监督。在这些改进的帮助下,ESRGAN得到了更好的视觉质量以及更逼真和自然的纹理。

在纹理和细节上,ESRGAN都优于SRGAN

SRGAN的思考与贡献

现有的超分辨率网络在不同的网络结构设计以及训练策略下,超分辨的效果得到了很大的提升,特别是PSNR指标。但是,基于PSNR指标的模型会倾向于生成过度平滑的结果,这些结果缺少必要的高频信息。PSNR指标与人类观察者的主观评价从根本上就不统一。

一些基于感知域信息驱动的方法已经提出来用于提升超分辨率结果的视觉质量。例如,感知域的损失函数提出来用于在特征空间(instead of 像素空间)中优化超分辨率模型;生成对抗网络通过鼓励网络生成一些更接近于自然图像的方法来提升超分辨率的质量;语义图像先验信息用于进一步改善恢复的纹理细节。

通过结合上面的方法,SRGAN模型极大地提升了超分辨率结果的视觉质量。但是SRGAN模型得到的图像和GT图像仍有很大的差距。

ESRGAN的改进

文章对这三点做出改进:1.网络的基本单元从基本的残差单元变为Residual-in-Residual Dense Block (RRDB);2.GAN网络改进为Relativistic average GAN (RaGAN);3.改进感知域损失函数,使用激活前的VGG特征,这个改进会提供更尖锐的边缘和更符合视觉的结果。

网络结构及思想

生成器部分

首先,作者参考SRResNet结构作为整体的网络结构,SRResNet的基本结构如下:

SRResNet基本结构

为了提升SRGAN重构的图像的质量,作者主要对生成器G做出如下改变:1.去掉所有的BN层;2.把原始的block变为Residual-in-Residual Dense Block (RRDB),这个block结合了多层的残差网络和密集连接。

如下图所示:

RRDB

思想:

BN层的影响

对于不同的基于PSNR的任务(包括超分辨率和去模糊)来说,去掉BN层已经被证明会提高表现和减小计算复杂度。BN层在训练时,使用一个batch的数据的均值和方差对该batch特征进行归一化,在测试时,使用在整个测试集上的数据预测的均值和方差。当训练集和测试集的统计量有很大不同的时候,BN层就会倾向于生成不好的伪影,并且限制模型的泛化能力。作者发现,BN层在网络比较深,而且在GAN框架下进行训练的时候,更会产生伪影。这些伪影偶尔出现在迭代和不同的设置中,违反了对训练稳定性能的需求。所以为了稳定的训练和一致的性能,作者去掉了BN层。此外,去掉BN层也能提高模型的泛化能力,减少计算复杂度和内存占用。

Trick:

除了上述的改进,作者也使用了一些技巧来训练深层网络:1.对残差信息进行scaling,即将残差信息乘以一个0到1之间的数,用于防止不稳定;2.更小的初始化,作者发现当初始化参数的方差变小时,残差结构更容易进行训练。

判别器部分

除了改进的生成器,作者也基于Relativistic GAN改进了判别器。判别器 D 使用的网络是 VGG 网络,SRGAN中的判别器D用于估计输入到判别器中的图像是真实且自然图像的概率,而Relativistic判别器则尝试估计真实图像相对来说比fake图像更逼真的概率。

具体而言,作者把标准的判别器换成Relativistic average Discriminator(RaD),所以判别器的损失函数定义为:

求均值的操作是通过对mini-batch中的所有数据求平均得到的,xf是原始低分辨图像经过生成器以后的图像。

可以观察到,对抗损失包含了xr和xf,所以这个生成器受益于对抗训练中的生成数据和实际数据的梯度,这种调整会使得网络学习到更尖锐的边缘和更细节的纹理。

感知域损失

文章也提出了一个更有效的感知域损失,使用激活前的特征(VGG16网络)。

感知域的损失当前是定义在一个预训练的深度网络的激活层,这一层中两个激活了的特征的距离会被最小化。与此相反,文章使用的特征是激活前的特征,这样会克服两个缺点。第一,激活后的特征是非常稀疏的,特别是在很深的网络中。这种稀疏的激活提供的监督效果是很弱的,会造成性能低下;第二,使用激活后的特征会导致重建图像与GT的亮度不一致。

使用激活前与激活后的特征的比较,(a)亮度;(b)细节

作者对使用的感知域损失进行了探索。与目前多数使用的用于图像分类的VGG网络构建的感知域损失相反,作者提出一种更适合于超分辨的感知域损失,这个损失基于一个用于材料识别的VGG16网络(MINCNet),这个网络更聚焦于纹理而不是物体。尽管这样带来的增益很小,但作者仍然相信,探索关注纹理的感知域损失对超分辨至关重要。

损失函数

经过上面对网络模块的定义和构建以后,再定义损失函数,就可以进行训练了。

对于生成器G,它的损失函数为:

代码解析:

https://zhuanlan.zhihu.com/p/54473407?utm_id=0

3.提取感知域损失的网络(Perceptual Network)

文章使用了一个用于材料识别的VGG16网络(MINCNet)来提取感知域特征,定义如下:

class MINCNet(nn.Module):
    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)
        self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
        self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
        self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)

    def forward(self, x):
        out = self.ReLU(self.conv11(x))
        out = self.ReLU(self.conv12(out))
        out = self.maxpool1(out)
        out = self.ReLU(self.conv21(out))
        out = self.ReLU(self.conv22(out))
        out = self.maxpool2(out)
        out = self.ReLU(self.conv31(out))
        out = self.ReLU(self.conv32(out))
        out = self.ReLU(self.conv33(out))
        out = self.maxpool3(out)
        out = self.ReLU(self.conv41(out))
        out = self.ReLU(self.conv42(out))
        out = self.ReLU(self.conv43(out))
        out = self.maxpool4(out)
        out = self.ReLU(self.conv51(out))
        out = self.ReLU(self.conv52(out))
        out = self.conv53(out)
        return out

再引入预训练参数,就可以进行特征提取:

class MINCFeatureExtractor(nn.Module):
    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, \
                device=torch.device('cpu')):
        super(MINCFeatureExtractor, self).__init__()

        self.features = MINCNet()
        self.features.load_state_dict(
            torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True)
        self.features.eval()
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        output = self.features(x)
        return output

网络插值思想

为了平衡感知质量和PSNR等评价值,作者提出了一个灵活且有效的方法—网络插值。具体而言,作者首先基于PSNR方法训练的得到的网络G_PSNR,然后再用基于GAN的网络G_GAN进行finetune。

然后,对这两个网络相应的网络参数进行插值得到一个插值后的网络G_INTERP:

这样就可以通过 α 值来调整效果

训练细节

训练细节

放大倍数:4,mini-batch:16

通过Matlab的bicubic函数对HR图像进行降采样得到LR图像。

HR patch大小:128×128(实验发现使用大的patch时,训练一个深层网络效果会更好,因为一个增大的感受域会帮助模型捕捉更具有语义的信息)

训练过程:

1.训练一个基于PSNR指标的模型(L1 Loss)

初始化学习率:2×1e-4

每200000个mini-batch学习率除以2

2.以1中训练的模型作为生成器的初始化

λ=5×10−3,η=0.01,β=0.2 (残差scaling系数)

初始学习率:1e-4,并在50k,100k,200k,300k迭代后减半。

一个基于像素损失函数进行优化的预训练模型会帮助基于GAN的模型生成更符合视觉的结果,原因如下:1.可以避免生成器不希望的局部最优;2.再预训练以后,判别器所得到的输入图像的质量是相对较好的,而不是完全初始化的图像,这样会使判别器更关注到纹理的判别。

优化器:Adam( β1=0.9,β2=0.999 );交替更新生成器和判别器,直到收敛。

生成器的设置:1.16层(基本的残差结构) 2.23层(RDDB)

数据集:DIV2K,Flickr2K,OST;(有丰富纹理信息的数据集会是模型产生更自然的结果)

可以看到,ESRGAN得到的图像PSNR值不高,但是从视觉效果上看会更好,Percpetual Index值更小(越小越好),而且ESRGAN在 PIRM-SR 竞赛上也获得了第一名(在Percpetual Index指标上)。

经过实验以后,作者得出结论:

1.去掉BN:并没有降低网络的性能,而且节省了计算资源和内存占用。而且发现当网络变深变复杂时,带BN层的模型更倾向于产生影响视觉效果的伪影。

2.使用激活前的特征:得到 的图像的亮度更准确,而且可以产生更尖锐的边缘和更丰富的细节。

3.RaGAN:产生更尖锐的边缘和更丰富的细节。

4.RDDB:更加提升恢复得到的纹理(因为深度模型具有强大的表示能力来捕获语义信息),而且可以去除噪声。

网络插值实验

为了平衡视觉效果和PSNR等性能指标,作者对网络插值参数 α 的取值进行了实验,结果如下:

总结

文章提出的ESRGAN在SRGAN的基础上做出了改进,包括去除BN层,基本结构换成RDDB,改进GAN中判别器的判别目标,以及使用激活前的特征构成感知域损失函数,实验证明这些改进对提升输出图像的视觉效果都有作用。此外,作者也使用了一些技巧来提升网络的性能,包括对残差信息的scaling,以及更小的初始化。最后,作者使用了一种网络插值的方法来平衡输出图像的视觉效果和PSNR等指标值。

SRGAN —使用GAN进行图像高分辨的开山之作

论文: https://arxiv.org/abs/1609.04802(CVPR 2017)

这篇文章第一次将将生成对抗网络用在了解决超分辨率问题上。将GAN引入SR领域

之前超分的研究虽然主要聚焦于“恢复细粒度的纹理细节”这个问题上,但将问题一直固定在最大化峰值信噪比(Peak Signal-to-Noise Ratio, PSNR)上,等价于 最小化与GT图像的均方重建误差(mean squared reconstruction error, MSE)。

而这也就导致:

  1. 高频细节(high-frequency details) 的丢失,整体图像过于平滑/模糊;
  2. 与人的视觉感知不一致,超分图像的精确性与人的期望不匹配(人可能更关注前景,而对背景清晰度要求不高)。
中间蓝色框是基于MSE所学到的超分图像所在像素空间,红色框是真实超分图像所在的像素空间流形,基于GAN的方法驱动重构图像往真实图像像素流形区域靠近,从而感知上更真实可信

从而提出3个改进:

  1. 新的backbone:SRResNet;
  2. GAN-based network 及 新的损失函数:
  3. adversarial loss:提升真实感(photo-realistic natural images);
  4. content loss:获取HR image和生成图像的感知相似性(perceptual similarity),而不只是像素级相似性(pixel similarity);或者说特征空间的相似性而不是像素空间的相似性。
  5. 使用主观评估手段:MOS,更加强调人的感知。

SRGAN算法改进细节:

生成网络是新的结构SRResNet(横跨主干网络的skip connection操作很关键),卷积核尺寸k,输出通道数n,步长s。头部后续接了两个输出通道数为256=64*4的卷积块,因为其中PixelShuffle*2会将feature map转化 (64, H*2, W*2)的输出(sub-pixel convolution操作),这样总共upscale *4
  • SRResNet和GAN-based Network

上图就是新的网络结构,G网络是SRResNet,论文使用了16个residual blocks;D网络为8次卷积操作(4次步长为2)+2次全连接层的VGG网络。

损失函数

生成网络的损失函数为:

包含:

论文对VGG高层特征和低层特征分别做了实验,最终选择可能关注更多图像内容的高层特征作为论文实验的损失特征图。

判别网络的损失函数为二分类交叉熵损失函数:

SRGAN实验设置

使用数据集Set5,Set14,BSD100,BSD300测试集对训练模型进行实验评估: 4×分辨率超分,然后对图像的每个边界移除4个像素点,最后center-cropped计算PSNR和SSIM,进行有效性统计分析。

随机采样ImageNet数据集中350张图像进行训练(参考源码):

SRGAN实验结果及分析

消融实验说明:

  1. skip-connection结构的有效性;
  2. PSNR体现不出人的感知(MOS);
  3. GAN-based Network能更好捕捉一些人的感知细节(高频信息?),MOS更高;
  4. VGG特征重建也有助于捕捉图像的部分感知细节。

联邦图机器学习

近年来,图已被广泛应用于表示和处理很多领域的复杂数据,如医疗、交通运输、生物信息学和推荐系统等。图机器学习技术是获取隐匿在复杂数据中丰富信息的有力工具,并且在像节点分类和链接预测等任务中,展现出很强的性能。
尽管图机器学习技术取得了重大进展,但大多数都需要把图数据集中存储在单机上。然而,随着对数据安全和用户隐私的重视,集中存储数据变的不安全和不可行。图数据通常分布在多个数据源(数据孤岛),由于隐私和安全的原因,从不同的地方收集所需的图数据变的不可行。
例如一家第三方公司想为一些金融机构训练图机器学习模型,以帮助他们检测潜在的金融犯罪和欺诈客户。每个金融机构都拥有私有客户数据,如人口统计数据以及交易记录等。每个金融机构的客户形成一个客户图,其中边代表交易记录。由于严格的隐私政策和商业竞争,各个机构的私有客户数据无法直接与第三方公司或其它他机构共享。同时,机构之间也可能有关联,这可以看作是机构之间的结构信息。因此面临的主要挑战是:在不直接访问每个机构的私有客户数据的情况下,基于私有客户图和机构间结构信息,来训练用于金融犯罪检测的图机器学习模型。
联邦学习(FL)是一种分布式机器学习方案,通过协作训练解决数据孤岛问题。它使参与者(即客户)能够在不共享其私有数据的情况下联合训练机器学习模型。因此,将 FL 与图机器学习相结合成为解决上述问题的有希望的解决方案。
本文中,来自弗吉尼亚大学的研究者提出联邦图机器学习(FGML,Federated Graph Machine Learning)。一般来说,FGML 可以根据结构信息的级别分为两种设置:
第一种是具有结构化数据的 FL,在具有结构化数据的 FL 中,客户基于其图数据协作训练图机器学习模型,同时将图数据保留在本地。
第二种是结构化 FL,在结构化 FL 中,客户端之间存在结构信息,形成客户端图。可以利用客户端图设计更有效的联合优化方法。

论文地址:https://arxiv.org/pdf/2207.11812.pdf
虽然 FGML 提供了一个有前景的蓝图,但仍存在一些挑战:
1、跨客户端的信息缺失。在具有结构化数据的 FL 中,常见的场景是每个客户端机器都拥有全局图的子图,并且一些节点可能具有属于其他客户端的近邻。出于隐私考虑,节点只能在客户端内聚合其近邻的特征,但无法访问位于其它客户端上的特征,这导致节点表示不足。
2、图结构的隐私泄漏。在传统 FL 中,不允许客户端公开其数据样本的特征和标签。在具有结构化数据的 FL 中,还应考虑结构信息的隐私。结构信息可以通过共享邻接矩阵直接公开,也可以通过传输节点嵌入间接公开。
3、跨客户端的数据异构性。与传统 FL 中数据异构性来自 non-IID 数据样本不同,FGML 中的图数据包含丰富的结构信息。同时,不同客户的图结构也会影响图机器学习模型的性能。 4、参数使用的策略。在结构化 FL 中,客户端图使客户端能够从其相邻客户端获取信息。在结构化 FL 中,需要设计有效的策略,以充分利用由中心服务器协调或完全分散的近邻信息。
为了应对上述挑战,研究人员开发了大量算法。目前各种算法主要关注标准 FL 中的挑战和方法,只有少数人尝试解决 FGML 中的具体问题和技术。有人发表对 FGML 进行分类的综述性论文,但没有总结 FGML 中的主要技术。而有的综述文章仅涵盖了 FL 中数量有限的相关论文,并非常简要地介绍了目前现有的技术。

而在今天介绍的这篇论文中,作者首先介绍 FGML 中两种问题设计的概念。然后,回顾了每种 shezhi 下的最新的技术进展,还介绍了 FGML 的实际应用。并对可用于 FGML 应用的可访问图数据集和平台进行总结。最后,作者给出了几个有前途的研究方向。文章的主要贡献包括:
FGML 技术分类:文章给出了基于不同问题的 FGML 分类法,并总结了每个设置中的关键挑战。
全面的技术回顾:文章全面概述了 FGML 中的现有技术。与现有其它综述性论文相比,作者不仅研究了更广泛的相关工作,而且提供了更详细的技术分析,而不是简单地列出每种方法的步骤。
实际应用:文章首次总结 FGML 的实际应用。作者根据应用领域对其进行分类,并介绍每个领域中的相关工作。
数据集和平台:文章介绍了 FGML 中现有的数据集和平台,对于想在 FGML 中开发算法和部署应用程序的工程师和研究人员非常有帮助。
未来方向:文章不仅指出了现有方法的局限性,而且给出了 FGML 未来的发展方向。

这里对文章的主要结构做下简介。第 2 节简要介绍了图机器学习中的定义以及 FGML 中两种设置的概念和挑战。第 3 节和第 4 节回顾了这两种设置中的主流技术。第 5 节进一步探讨了 FGML 在现实世界中的应用。第 6 节介绍了相关 FGML 论文中使用的开放图数据集和 FGML 的两个平台。在第 7 节中提供了未来可能的发展方向。最后第 8 节对全文进行了总结。
更多详细信息请参考原论文。

用语言模型学习表示蛋白质的功能特性

以数据为中心的方法已被用于开发用于阐明蛋白质未表征特性的预测方法;然而,研究表明,这些方法应进一步改进,以有效解决生物医学和生物技术中的关键问题,这可以通过更好地代表手头的数据来实现。新的数据表示方法主要从在自然语言处理方面取得突破性改进的语言模型中汲取灵感。最近,这些方法已应用于蛋白质科学领域,并在提取复杂的序列-结构-功能关系方面显示出非常有希望的结果。在这项研究中,土耳其中东科技大学(Middle East Technical University)的研究人员,首先对每种方法进行分类/解释,然后对它们的预测性能进行基准测试,对蛋白质表示学习进行了详细调查:(1)蛋白质之间的语义相似性,(2)基于本体的蛋白质功能,(3)药物靶蛋白家族和(4)突变后蛋白质-蛋白质结合亲和力的变化。这项研究的结论将有助于研究人员将基于机器/深度学习的表示技术应用于蛋白质数据以进行各种预测任务,并激发新方法的发展。该研究以「Learning functional properties of proteins with language models」为题,于 2022 年 3 月 21 日发布在《Nature Machine Intelligence》。

蛋白质科学是一门广泛的学科,它通过实验室实验(即蛋白质组学)和计算方法(例如分子建模、机器学习、数据科学)分析单个蛋白质以及生物体的整个蛋白质组,最终创建准确且可重复使用的方法用于生物医学和生物技术。蛋白质信息学可以定义为蛋白质科学的计算和以数据为中心的分支,通过它对蛋白质的定量方面进行建模。蛋白质的功能表征对于开发新的有效的生物医学策略和生物技术产品至关重要。截至 2021 年 5 月,UniProt 蛋白质序列和注释知识库中约有 2.15 亿条蛋白质条目;然而,其中只有 56 万份(约 0.26%)由专家手动审查和注释,这表明当前的排序(数据生产)和注释(标签)能力之间存在很大差距。这种差距主要是由于从湿实验室实验及其手动管理中获得结果的成本较高,同时具有时间密集性。为了补充基于实验和管理的注释,使用计算机方法势在必行。在这种情况下,许多研究小组一直致力于开发新的计算方法来预测蛋白质的酶活性、生物物理特性、蛋白质和配体相互作用、三维结构以及最终的功能。蛋白质功能预测(PFP)可以定义为自动或半自动地将功能定义分配给蛋白质。生物分子功能的主要术语被编入基因本体论(GO)系统;这是一个概念的分层网络,用于注释基因和蛋白质的分子功能,以及它们的亚细胞定位和它们所涉及的生物过程。PFP 最全面的基准项目是功能注释的关键评估(CAFA)挑战;在该项目中,参与者预测一组目标蛋白的基于 GO 的功能关联,这些目标蛋白的功能后来通过手动调节确定,用于评估参与预测因子的性能;迄今为止的 CAFA 挑战表明,PFP 仍然是一个开放的问题。以前的研究已经表明,复杂的计算问题,其中特征是高维的并且具有复杂/非线性关系,适合基于深度学习的技术。这些技术可以有效地从嘈杂的高维输入数据中学习与任务相关的表示。因此,深度学习已成功应用于计算机视觉、自然语言处理和生命科学等各个领域。生物分子的特征(例如,基因、蛋白质、RNA 等)应被提取并编码为定量/数值向量(即表示),以用于基于机器/深度学习的预测建模。给定生物分子的原始和高维输入特征,表示模型将该特征向量计算为该生物分子的简洁和正交表示。经过优化训练的监督预测系统可以有效地学习数据集中样本的特征,并使用这些表示作为输入来执行预测任务(例如,序列上的 DNA 结合区域、生化特性、亚细胞定位等)。蛋白质表示方法可以分为两大类;(1)经典表示(即模型驱动的方法),使用预定义的属性规则生成,例如基因/蛋白质之间的进化关系或氨基酸的物理化学性质,以及(2)数据驱动的表示,使用统计和机器学习算法(例如人工神经网络)构建,这些算法针对预定义任务进行训练,例如预测序列上的下一个氨基酸。之后,训练模型的输出——即表示特征向量——可以用于其他与蛋白质信息学相关的任务,例如功能预测。从这个意义上说,表示学习模型利用了知识从一个任务到另一个任务的转移。这个过程的广义形式被称为迁移学习,据报道它在时间和成本方面是一种高效的数据分析方法。因此,蛋白质表示学习模型最大限度地减少了对数据标记的需求。蛋白质表示学习是一个年轻但高度活跃的研究领域,主要受到自然语言处理 (NLP) 方法的启发。因此,蛋白质表示学习方法在文献中经常被称为蛋白质语言模型。之前的研究表明,各种蛋白质表示学习方法,尤其是那些结合了深度学习的方法,已经成功地提取了蛋白质的相关固有特征。参见:https://www.nature.com/articles/s42256-022-00457-9/tables/1尽管有研究评估学习的蛋白质表示模型,但需要进行全面的调查和基准测试,以便在学习蛋白质的多个方面(包括基于本体的功能定义、语义关系、家族和相互作用)的背景下系统地评估这些方法。在新的研究中,中东科技大学的研究人员对自 2015 年以来提出的可用蛋白质表示学习方法进行了全面调查,并通过详细的基准分析测量了这些方法捕获蛋白质功能特性的潜力。涵盖了经典和基于人工学习的方法,并深入了解了它们各自代表蛋白质的方法。研究人员根据它们的技术特征和应用对这些方法进行分类。为了评估每个表示模型在多大程度上捕获了功能信息的不同方面,该团队构建并应用了基于以下的基准:(1) 蛋白质之间的语义相似性推断,(2) 基于本体的 PFP,(3) 药物靶蛋白家族分类,(4) 蛋白质-蛋白质结合亲和力估计。

图示:研究的示意图。(来源:论文)此外,该团队还提供了相关的基准测试软件(Protein Representation Benchmark, PROBE),它允许人们轻松评估任何表示方法在该团队定义的四个基准测试任务中的性能。研究人员希望该工作能够,为希望将基于机器/深度学习的表示技术,应用于生物分子数据进行预测建模的研究人员提供信息。也希望这项研究能够激发新的想法,以开发新颖、复杂和强大的以数据为中心的方法来解决蛋白质科学中的开放问题。基于表示学习的方法在蛋白质功能分析中的表现通常优于经典方法在该团队所有的基准测试中,观察到学习表示(尤其是大型模型)在预测性能方面优于经典模型,证实了基于人工学习的数据驱动方法在表示生物分子的功能特性方面的优势。另一方面,在 PFP 预测基准的分子功能类别中,HMMER 是一种基于隐马尔可夫模型(HMM)的生物分子相似性检测和功能注释的经典方法,可以与基于深度学习的蛋白质表示方法竞争。该结果与先前的研究一致,即序列相似性与蛋白质的生化特性高度相关,以至于使用此特征的简单矢量表示几乎可以执行复杂的序列建模方法。鉴于这些结果,研究人员表示,将同源信息明确纳入表征学习模型的训练可能会导致考虑到预测性能的改进。这从基于深度学习的高性能蛋白质结构预测器(例如 RoseTTAFold 和 AlphaFold2)中也很明显,它们使用多个序列比对来显着丰富基于序列的输入。他们认为,在当前状态下,学习到的蛋白质表示对于其他原因也是必不可少的。模型设计和训练数据类型/来源是表征学习的关键因素蛋白质表征学习中最关键的因素之一是表征模型的设计。例如,在这里的基准测试中,包含了两种类型的 BERT 模型。TAPE-BERT-PFAM 接受了 3200 万个蛋白质结构域序列的训练。ProtBERT-BFD 训练有 21 亿个宏基因组序列片段;然而,这两者之间的性能差异是微不足道的。另一方面,使用相同 2.1B 数据集(例如 ProtT5-XL)训练的更复杂的模型在大多数基准测试中表现出更好的性能。因此,研究人员认为模型设计/架构是最重要的(与这些方法的设计/架构相关的信息在方法中给出,并在结果部分就预测性能进行了讨论)。关于训练数据源的另一个发现是,合并多种数据类型可能会在与功能相关的预测任务中带来更好的性能。例如,AAC 和 APAAC 都使用氨基酸组成;然而,APAAC 还在其表示模型中添加了物理化学特性,并且在语义相似性推断和 PFP 基准测试中表现得更好。同样,Mut2Vec 结合了突变配置文件、PPI 和文本数据,并取得了最佳性能;尤其是在语义相似性推理基准测试中。在蛋白质表征学习方法的构建和评估过程中应考虑潜在的数据泄漏数据泄漏可以定义为机器学习方法的训练和验证阶段之间的知识意外泄漏,导致性能测量过于乐观,是性能测试期间应考虑的关键问题。研究人员分析发现,某些表示模型在与这些模型预训练的任务生物学相关的任务中表现良好;尽管数据和实际任务彼此不同。蛋白质表征学习的现状和挑战蛋白质表示学习领域存在一些挑战。尽管大多数蛋白质表示学习模型(迄今为止提出的)都是源自 NLP 模型(基于 LSTM/transformer 的深度学习模型),但建模语言和蛋白质的问题之间存在结构差异。据估计,一个以美国为母语的成年英语使用者,平均使用 46,200 个词条和多词表达;然而,蛋白质中只有 20 种不同的氨基酸,它们被表示模型以类似于语言的引理的方式处理。这些 NLP 模型为每个单词计算一个表示向量。类似地,当这种方法应用于蛋白质序列数据时,会为每个氨基酸计算一个表示向量。这些向量被汇集起来,为每个句子/文档和蛋白质创建固定大小的向量,分别用于 NLP 和蛋白质信息学任务。因此,与 NLP 相比,蛋白质表示中的少量构建块(即 20 个氨基酸)可能为较小的模型在与蛋白质表示学习领域中的较大模型竞争时带来优势。因此,鼓励对蛋白质序列特异性学习模型进行更多研究。

图示:蛋白质语义相似性推理基准结果。(来源:论文)模型的可解释性对于理解模型为何如此行事至关重要。在可解释表示中,所有特征都以隔离形式编码,这意味着向量上每个位置对应的特征是已知的;然而,该研究中研究的大多数学习蛋白质表示是不可解释/可解释的。例如,蛋白质中 TIM 桶结构的存在可能在其表示向量的第五位编码,而分子量信息可能在第三和第四位之间共享。一般来说,在数据科学领域,解缠结研究试图将样本的真实属性与输出向量的各个位置联系起来。蛋白质表征的解开是一个新课题,迄今为止只有少数表征模型开发人员探讨了这个问题。因此,尚不存在系统方法,并且需要新的框架来标准化评估蛋白质表示模型的可解释性。迄今为止提出的大多数蛋白质表示模型仅使用一种类型的数据(例如蛋白质序列)进行训练。然而,蛋白质知识与多种类型的生物信息相关,例如 PPI、翻译后修饰、基因/蛋白质(共)表达等;只有少数可用的蛋白质表示模型使用了多种类型的数据。在该团队的基准研究时所涉及的方法中,Mut2Vec 就是一个这样的例子,它结合了 PPI、突变和生物医学文本,并且比 GO BP 和基于 CC 的 PFP 中的许多仅基于序列的表示产生了更准确的结果。研究人员建议整合其他类型的蛋白质相关数据,尤其是进化关系,可能会进一步提高预测任务的准确性。MSA-Transformer 和无向图模型(例如 DeepSequence)通过深度学习利用同源信息。DeepSequence 使用 MSA 的后验分布计算潜在因子,而 MSA-Transformer 使用基于行和列的注意力来结合 MSA 和蛋白质语言模型。尽管 MSA-Transformer 在基准测试中表现出平均性能,但在之前的文献中,发现它在二级结构和接触预测任务上是成功的,这表明 MSA-Transformer 具有捕捉进化关系的能力。与此相关的是,文献中明确要求能够从广义的角度有效地表示蛋白质的整体蛋白质载体,用于各种不同的蛋白质信息学相关目的。研究人员认为,可以通过连接多个先前使用不同类型的生物数据独立构建的表示向量来创建这些整体表示,并使用这些向量的集成版本训练新模型以用于高级监督任务,例如预测生物过程和/或复杂的结构特征。构建这些整体表示的另一种方法是通过图表示学习直接学习整合多种蛋白质关系(例如,其他蛋白质、配体、疾病、表型、功能、途径等)的异构图。

图示:基于本体的蛋白质功能预测基准结果。(来源:论文)蛋白质表示学习方法可用于设计新蛋白质蛋白质设计是生物技术的主要挑战之一。合理的蛋白质设计涉及评估许多不同替代序列/结构的活性和功能,以为实验验证提供最有希望的候选者,这可以看作是一个优化问题。为此目的要探索的序列空间是巨大的。例如,人类蛋白质的平均长度约为 350 个氨基酸,其中存在 20^350 种不同的组合,尽管其中大多数是非功能性序列。在过去的几十年中,计算方法已被用于蛋白质设计,并且这些方法已经产生了有希望的结果,特别是在酶设计、蛋白质折叠和组装以及蛋白质表面设计方面。因此已经开发出高效的抗体和生物传感器。其中一些方法使用量子力学计算、分子动力学和统计力学,每种方法的计算成本都非常高,并且需要专业知识。类似的缺点也表现在主流的蛋白质设计软件,如 Rosetta。最近的研究表明,基于人工学习的生成建模可用于从头蛋白质设计。在机器学习领域,生成建模与判别建模相反,是一种生成合成样本的方法,这些样本服从从真实样本中学习到的概率分布。这是通过有效地学习训练数据集中样本的表示来实现的。

图示:药物靶蛋白家族分类基准结果。(来源:论文)深度学习最近已成为生成模型架构的关键方法,并已应用于包括蛋白质/肽设计在内的各个领域。例如,Madani 团队使用蛋白质语言模型从头开始设计属于不同蛋白质家族的新功能蛋白质,并通过湿实验室实验验证了他们的设计。这些研究表明,表示学习对于蛋白质和配体(药物)设计中的新应用至关重要。

图示:蛋白质-蛋白质结合亲和力估计基准结果。(来源:论文)研究人员相信蛋白质表示学习方法将在不久的将来对蛋白质科学的各个领域产生影响,并在现实世界中应用,这要归功于它们在输入级别集成异构蛋白质数据(即理化性质/属性、功能注释等)的灵活性,以及它们有效提取复杂潜在特征的能力。论文链接:https://www.nature.com/articles/s42256-022-00457-9

CE Loss 与 BCE Loss (分类问题损失函数)

有两个问题曾困扰着我:

  1. 为何MSE loss是一种回归问题的loss,不可以用在分类问题?而非要用CE或BCE呢?
  2. 为何CE与softmax激活函数搭配,而BCE与sigmoid搭配?有什么理由?

在学习过后,我发现这个问题在数学上有多种理解的角度,而结论却是一致的。在这里我梳理出一种角度,在未来如果有新的理解再进行补充。

该思路学习自李宏毅老师的机器学习课程

这说明,如果用MSE loss来训练分类问题,不论预测接近真实值或是接近错误值,梯度都很小。这也就解释了为何我们需要CE或BCE损失来处理分类问题。

2. BCE 损失函数

既然在分类问题中,MSE损失函数的梯度不能满足需要,现在我们来推导BCE损失函数的梯度。

3. CE 损失函数

江湖上流传一句话,交叉熵损失可以采用“sigmoid+BCE”或是“softmax+CE”。对于前者,上一部分已经对其回传的梯度进行了推导,证实了其合理性。

在这一部分我们对“softmax+CE”的梯度进行推导。

4. 应用

在Pytorch中,“sigmoid+BCE”对应的是torch.nn.BCEWithLogitsLoss,而“softmax+CE”对应的是torch.nn.CrossEntropyLoss

具体参数和用法可以参考 BCEWithLogitsLoss 和 CrossEntropyLoss

在分类问题中,如果遇到类别间不互斥的情况,只能采用“sigmoid+BCE”;

如果遇到类别间互斥的情况(只能有一类胜出),“sigmoid+BCE”化为多个二分类问题与“softmax+CE”直接进行分类都是有被用到的方法。

Softmax函数和Sigmoid函数的区别与联系

1. 前言

对于Softmax函数和Sigmoid函数,我们分为两部分讲解,第一部分:对于分类任务,第二部分:对于二分类任务(详细讲解)。

优点:1. Sigmoid函数的输出在(0,1)之间,输出范围有限,优化稳定,可以用作输出层。2. 连续函数,便于求导。

缺点:1. 最明显的就是饱和性,从上图也不难看出其两侧导数逐渐趋近于0,容易造成梯度消失。2.激活函数的偏移现象。Sigmoid函数的输出值均大于0,使得输出不是0的均值,这会导致后一层的神经元将得到上一层非0均值的信号作为输入,这会对梯度产生影响。 3. 计算复杂度高,因为Sigmoid函数是指数形式。

2.2 Softmax函数

Softmax =多类别分类问题=只有一个正确答案=互斥输出(例如手写数字,鸢尾花)。构建分类器,解决只有唯一正确答案的问题时,用Softmax函数处理各个原始输出值。Softmax函数的分母综合了原始输出值的所有因素,这意味着,Softmax函数得到的不同概率之间相互关联。

Softmax函数是二分类函数Sigmoid在多分类上的推广,目的是将多分类的结果以概率的形式展现出来。如图2所示,Softmax直白来说就是将原来输出是3,1,-3通过Softmax函数一作用,就映射成为(0,1)的值,而这些值的累和为1(满足概率的性质),那么我们就可以将它理解成概率,在最后选取输出结点的时候,我们就可以选取概率最大(也就是值对应最大的)结点,作为我们的预测目标。

由于Softmax函数先拉大了输入向量元素之间的差异(通过指数函数),然后才归一化为一个概率分布,在应用到分类问题时,它使得各个类别的概率差异比较显着,最大值产生的概率更接近1,这样输出分布的形式更接近真实分布。

Softmax可以由三个不同的角度来解释。从不同角度来看softmax函数,可以对其应用场景有更深刻的理解:

  1. softmax可以当作arg max的一种平滑近似,与arg max操作中暴力地选出一个最大值(产生一个one-hot向量)不同,softmax将这种输出作了一定的平滑,即将one-hot输出中最大值对应的1按输入元素值的大小分配给其他位置。
  2. softmax将输入向量归一化映射到一个类别概率分布,即 n 个类别上的概率分布(前文也有提到)。这也是为什么在深度学习中常常将softmax作为MLP的最后一层,并配合以交叉熵损失函数(对分布间差异的一种度量)。
  3. 从概率图模型的角度来看,softmax的这种形式可以理解为一个概率无向图上的联合概率。因此你会发现,条件最大熵模型与softmax回归模型实际上是一致的,诸如这样的例子还有很多。由于概率图模型很大程度上借用了一些热力学系统的理论,因此也可以从物理系统的角度赋予softmax一定的内涵。

2.3 总结

  1. 如果模型输出为非互斥类别,且可以同时选择多个类别,则采用Sigmoid函数计算该网络的原始输出值。
  2. 如果模型输出为互斥类别,且只能选择一个类别,则采用Softmax函数计算该网络的原始输出值。
  3. Sigmoid函数可以用来解决多标签问题,Softmax函数用来解决单标签问题。[1]
  4. 对于某个分类场景,当Softmax函数能用时,Sigmoid函数一定可以用。

3. 二分类任务

对于二分类问题来说,理论上,两者是没有任何区别的。由于我们现在用的Pytorch、TensorFlow等框架计算矩阵方式的问题,导致两者在反向传播的过程中还是有区别的。实验结果表明,两者还是存在差异的,对于不同的分类模型,可能Sigmoid函数效果好,也可能是Softmax函数效果。

然后我们再分析为什么两者之间还存着差异(以Pytorch为例):

首先我们要明白,当你用Sigmoid函数的时候,你的最后一层全连接层的神经元个数为1,而当你用Softmax函数的时候,你的最后一层全连接层的神经元个数是2。这个很好理解,因为Sigmoid函数只有是目标和不是目标之分,实际上只存在一类目标类,另外一个是背景类。而Softmax函数将目标分类为了二类,所以有两个神经元。这也是导致两者存在差异的主要原因。

Sigmoid函数针对两点分布提出。神经网络的输出经过它的转换,可以将数值压缩到(0,1)之间,得到的结果可以理解成分类成目标类别的概率P,而不分类到该类别的概率是(1 – P),这也是典型的两点分布的形式。

Softmax函数本身针对多项分布提出,当类别数是2时,它退化为二项分布。而它和Sigmoid函数真正的区别就在——二项分布包含两个分类类别(姑且分别称为A和B),而两点分布其实是针对一个类别的概率分布,其对应的那个类别的分布直接由1-P得出。

简单点理解就是,Sigmoid函数,我们可以当作成它是对一个类别的“建模”,将该类别建模完成,另一个相对的类别就直接通过1减去得到。而softmax函数,是对两个类别建模,同样的,得到两个类别的概率之和是1。

神经网络在做二分类时,使用Softmax还是Sigmoid,做法其实有明显差别。由于Softmax是对两个类别(正反两类,通常定义为0/1的label)建模,所以对于NLP模型而言(比如泛BERT模型),Bert输出层需要通过一个nn.Linear()全连接层压缩至2维,然后接Softmax(Pytorch的做法,就是直接接上torch.nn.CrossEntropyLoss);而Sigmoid只对一个类别建模(通常就是正确的那个类别),所以Bert输出层需要通过一个nn.Linear()全连接层压缩至1维,然后接Sigmoid(torch就是接torch.nn.BCEWithLogitsLoss)。

总而言之,Sotfmax和Sigmoid确实在二分类的情况下可以化为相同的数学表达形式,但并不意味着二者有一样的含义,而且二者的输入输出都是不同的。Sigmoid得到的结果是“分到正确类别的概率和未分到正确类别的概率”,Softmax得到的是“分到正确类别的概率和分到错误类别的概率”。

一种常见的错法(NLP中):即错误地将Softmax和Sigmoid混为一谈,再把BERT输出层压缩至2维的情况下,却用Sigmoid对结果进行计算。这样我们得到的结果其意义是什么呢?

假设我们现在BERT输出层经nn.Linear()压缩后,得到一个二维的向量:

[-0.9419267177581787, 1.944047451019287]

对应类别分别是(0,1)。我们经过Sigmoid运算得到:

tensor([0.2805, 0.8748])

前者0.2805指的是分类类别为0的概率,0.8748指的是分类类别为1的概率。二者相互独立,可看作两次独立的实验(显然在这里不适用,因为0-1类别之间显然不是相互独立的两次伯努利事件)。所以显而易见的,二者加和并不等于1。

若用softmax进行计算,可得:

tensor([0.0529, 0.9471])

这里两者加和是1,才是正确的选择。

经验:

对于NLP而言,这两者之间确实有差别,Softmax的处理方式有时候会比Sigmoid的处理方式好一点。

对于CV而言,这两者之间也是有差别的,Sigmoid的处理方式有时候会比Softmax的处理方式好一点。

逼真度超越「AI设计师」DALL·E 2!谷歌大脑推出新的文本生成图像模型——Imagen

论文: https://arxiv.org/abs/2205.11487

demo地址:https://imagen.research.google/

文本生成图像模型界又出新手笔!

这次的主角是Google Brain推出的 Imagen,再一次突破人类想象力,将文本生成图像的逼真度和语言理解提高到了前所未有的新高度!比前段时间OpeAI家的DALL·E 2更强!

话不多说,我们来欣赏这位AI画师的杰作~A brain riding a rocketship heading towards the moon.(一颗大脑乘着火箭飞向月球。)


Imagen的工作原理

本方案的主要内容包括三部分,如下图所示,首先是文本编码器部分,本文直接使用的是T5,然后是Diffusion生成模型,这部分与Glide类似,都是使用classifier-free引导的方式。最后就是对生成的小图进行超分,变为大图。下面分模块详细介绍:

2.1. text编码部分

文本编码器部分对比了BERT(base模型参数量:1.1亿)、CLIP(0.63亿)以及T5(模型参数量:110亿),后来发现T5效果最好。并且还舍弃了之前常规的基于<text, image>数据对,对Text Encoder进行finetune的流程。理由个很直接,因为参数量大好几个量级,不需要finetune。

2.2. Diffusion生成部分

这部分跟Glide中的基本相近,可以直接与Glide文章中对eps建模公式进行对比。只是在uncondition的时候没有使用空的文本表示。

text condition diffusion model using classifier-free guidance

classifier-free应该是diffusion必备的优化方式了。融合text特征到生成模型中的部分也可以直接看Glide。这部分的模型还是典型的64*64的U-Net结构,如下图2所示。之所以选择小模型主要还是diffusion的迭代过程太长,导致生成过程慢,所以生成小图是提速最方便的,但是也注定了无法生成比较复杂内容和空间关系的大图。UNet网络由左编码部分,右解码部分和下两个卷积+激活层组成。

编码部分:左边红框架构中是由4个重复结构组成:2个3×3卷积层,非线形ReLU层和一个stride为2的2×2 max pooling层。每一次下采样特征通道的数量加倍。

解码部分:右边蓝框,反卷积也有4个重复结构组成。每个重复结构前先使用反卷积,每次反卷积后特征通道数量减半,特征图大小加倍反卷积之后,反卷积的结果和编码部分对应步骤的特征图拼接起来。拼接后的特征图再进行2次3×3的卷积,最后一层的卷积核为1×1 的卷积核,将64通道的特征图转化为特定类别数量的结果

2.3. Diffusion超分部分

超分的好处是可以直接带来效率的提高,但是可能会影响最终生成的细节失真,本文在本文提到通过噪声的增强,可以提升模型在控制失真上鲁棒性,具体原理还是要详细看论文了。这部分模型使用的是U-Net的变体Efficient U-Net模型,有点就是提升记忆感知、推理效率以及训练收敛速度。

大型预训练语言模型×级联扩散模型

Imagen使用在纯文本语料中进行预训练的通用大型语言模型(例如T5),它能够非常有效地将文本合成图像:在Imagen中增加语言模型的大小,而不是增加图像扩散模型的大小,可以大大地提高样本保真度和图像-文本对齐。

Imagen的研究突出体现在:

  • 大型预训练冻结文本编码器对于文本到图像的任务来说非常有效;
  • 缩放预训练的文本编码器大小比缩放扩散模型大小更重要;
  • 引入一种新的阈值扩散采样器,这种采样器可以使用非常大的无分类器指导权重;
  • 引入一种新的高效U-Net架构,这种架构具有更高的计算效率、更高的内存效率和更快的收敛速度;
  • Imagen在COCO数据集上获得了最先进的FID分数7.27,而没有对COCO进行任何训练,人类评分者发现,Imagen样本在图像-文本对齐方面与COCO数据本身不相上下。

2
引入新基准DrawBench

为了更深入地评估文本到图像模型,Google Brain 引入了DrawBench,这是一个全面的、具有挑战性的文本到图像模型基准。通过DrawBench,他们比较了Imagen与VQ-GAN+CLIP、Latent Diffusion Models和DALL-E 2等其他方法,发现人类评分者在比较中更喜欢Imagen而不是其他模型,无论是在样本质量上还是在图像-文本对齐方面。

  • 并排人类评估;
  • 对语意合成性、基数性、空间关系、长文本、生词和具有挑战性的提示几方面提出了系统化的考验;
  • 由于图像-文本对齐和图像保真度的优势,相对于其他方法,用户强烈倾向于使用Imagen。

3 打开了潘多拉魔盒?

像Imagen这样从文本生成图像的研究面临着一系列伦理挑战。

首先,文本-图像模型的下游应用多种多样,可能会从多方面对社会造成影响。Imagen以及一切从文本生成图像的系统都有可能被误用的潜在风险,因此社会要求开发方提供负责任的开源代码和演示。基于以上原因,Google决定暂时不发布代码或进行公开演示。而在未来的工作中,Google将探索一个负责任的外部化框架,从而将各类潜在风险最小化。

其次,文本到图像模型对数据的要求导致研究人员严重依赖于大型的、大部分未经整理的、网络抓取的数据集。虽然近年来这种方法使算法快速进步,但这种性质的数据集往往会夹带社会刻板印象、压迫性观点、对边缘群体有所贬损等“有毒”信息。

为了去除噪音和不良内容(如色情图像和“有毒”言论),Google对训练数据的子集进行了过滤,同时Google还使用了众所周知的LAION-400M数据集进行过滤对比,该数据集包含网络上常见的不当内容,包括色情图像、种族主义攻击言论和负面社会刻板印象。Imagen依赖于在未经策划的网络规模数据上训练的文本编码器,因此继承了大型语言模型的社会偏见和局限性。这说明Imagen可能存在负面刻板印象和其他局限性,因此Google决定,在没有进一步安全措施的情况下,不会将Imagen发布给公众使用。

OpenAI 发布文字生成图像工具 DALL·E 2

paper: https://arxiv.org/abs/2204.06125

官方地址: https://openai.com/dall-e-2/

论文讲解:DALL·E 2【论文精读】

Dall-mini:一个小型的Dall开源库

提供一个api: https://www.craiyon.com/

DALL·E 2 能做什么:

官网

DALL-E 2不仅能按用户指令生成明明魔幻,却又看着十分合理不明觉厉的图片。作为一款强大的模型,目前我们已知DALL-E 2还可以:

  • 生成特定艺术风格的图像,仿佛出自该种艺术风格的画家之手,十分原汁原味!
  • 保持一张图片显着特征的情况下,生成该图片的多种变体,每一种看起来都十分自然;
  • 修改现有图像而不露一点痕迹,天衣无缝。

1、根据文字生成图片(概念 属性 风格):

An astronaut riding a horse in a photorealistic style
A bowl of soup that is a portal to another dimension as digital art

2、 DALL·E 2 can take an image and create different
variations of it inspired by the original.(输入图片,生成跟原始图片相似的图片)

DALL-E 2目前曝光的功能令人瞠目结舌,不禁激起了众多AI爱好者的讨论,这样一个强大模型,它的工作原理到底是什么?!

3、DALL·E 2 can make realistic edits to existing images from a natural language caption. It can add and remove elements while taking shadows, reflections, and textures into account.

·

1、工作原理:简单粗暴

图源:https://arxiv.org/abs/2204.0612

针对图片生成这一功能来说,DALL-E 2的工作原理剖析出来看似并不复杂:

  1. 首先,将文本提示输入文本编码器,该训练过的编码器便将文本提示映射到表示空间。
  2. 接下来,称为先验的模型将文本编码映射到相应的图像编码,图像编码捕获文本编码中包含的提示的语义信息。
  3. 最后,图像解码模型随机生成一幅从视觉上表现该语义信息的图像。

2、工作细节:处处皆奥妙

可是以上步骤说起来简单,分开看来却是每一步都有很大难度,让我们来模拟DALL-E 2的工作流程,看看究竟每一步都是怎么走通的。

我们的第一步是先看看DALL-E 2是怎么学习把文本和视觉图像联系起来的。

第一步 – 把文本和视觉图像联系起来

输入“泰迪熊在时代广场滑滑板”的文字提示后,DALL-E 2生成了下图:

DALL-E 2是怎么知道“泰迪熊”这个文本概念在视觉空间里是什么样子的?

其实DALL-E 2中的文本语义和与其相对的视觉图片之间的联系,是由另一个OpenAI模型CLIP(Contrastive Language-Image Pre-training)学习的。

CLIP接受过数亿张图片及其相关文字的训练,学习到了给定文本片段与图像的关联。

也就是说,CLIP并不是试图预测给定图像的对应文字说明,而是只学习任何给定文本与图像之间的关联。CLIP做的是对比性而非预测性的工作。

整个DALL-E 2模型依赖于CLIP从自然语言学习语义的能力,所以让我们看看如何训练CLIP来理解其内部工作。

CLIP训练

训练CLIP的基本原则非常简单:

  1. 首先,所有图像及其相关文字说明都通过各自的编码器,将所有对象映射到m维空间。
  2. 然后,计算每个(图像,文本)对的cos值相似度。
  3. 训练目标是使N对正确编码的图像/标题对之间的cos值相似度最大化,同时使N2 – N对错误编码的图像/标题对之间的cos值相似度最小化。
CLIP训练流程

CLIP对DALL-E 2的意义

CLIP几乎就是DALL-E 2的心脏,因为CLIP才是那个把自然语言片段与视觉概念在语义上进行关联的存在,这对于生成与文本对应的图像来说至关重要。

第二步 – 从视觉语义生成图像

训练结束后,CLIP模型被冻结,DALL-E 2进入下一个任务——学习怎么把CLIP刚刚学习到的图像编码映射反转。CLIP学习了一个表示空间,在这个表示空间当中很容易确定文本编码和视觉编码的相关性, 我们需要学会利用表示空间来完成反转图像编码映射这个任务。

而OpenAI使用了它之前的另一个模型GLIDE的修改版本来执行图像生成。GLIDE模型学习反转图像编码过程,以便随机解码CLIP图像嵌入。

“一只吹喷火喇叭的柯基”一图经过CLIP的图片编码器,GLIDE利用这种编码生成保持原图像显着特征的新图像。 图源:https://arxiv.org/abs/2204.06125

如上图所示,需要注意的是,我们的目标不是构建一个自编码器并在给定的嵌入条件下精确地重建图像,而是在给定的嵌入条件下生成一个保持原始图像显着特征的图像。为了进行图像生成,GLIDE使用了扩散模型(Diffusion Model)。

何为扩散模型?

扩散模型是一项受热力学启发的发明,近年来越来越受到学界欢迎。扩散模型学习通过逆转一个逐渐噪声过程来生成数据。如下图所示,噪声处理过程被视为一个参数化的马尔可夫链,它逐渐向图像添加噪声使其被破坏,最终(渐近地)导致纯高斯噪声。扩散模型学习沿着这条链向后走去,在一系列步骤中逐渐去除噪声,以逆转这一过程。

扩散模型示意图 图源:https://arxiv.org/pdf/2006.11239.pdf

如果训练后将扩散模型“切成两半”,则可以通过随机采样高斯噪声来生成图像,然后对其去噪,生成逼真的图像。大家可能会意识到这种技术很容易令人联想到用自编码器生成数据,实际上扩散模型和自编码器确实是相关的。

GLIDE的训练

虽然GLIDE不是第一个扩散模型,但其重要贡献在于对模型进行了修改,使其能够生成有文本条件的图像。

GLIDE扩展了扩散模型的核心概念,通过增加额外的文本信息来增强训练过程,最终生成文本条件图像。让我们来看看GLIDE的训练流程:

动图封面

下面是一些使用GLIDE生成的图像示例。作者指出,就照片真实感和文本相似度两方面而言,GLIDE的表现优于DALL-E(1)。

DALL-E 2使用了一种改进的GLIDE模型,这种模型以两种方式使用投影的CLIP文本嵌入。第一种方法是将它们添加到GLIDE现有的时间步嵌入中,第二种方法是创建四个额外的上下文标记,这些标记连接到GLIDE文本编码器的输出序列。

GLIDE对于DALL-E 2的意义

GLIDE对于DALL-E 2亦很重要,因为GLIDE能够将自己按照文本生成逼真图像的功能移植到DALL-E 2上去,而无需在表示空间中设置图像编码。因此,DALL-E 2使用的修改版本GLIDE学习的是根据CLIP图像编码生成语义一致的图像。

第三步 – 从文本语义到相应的视觉语义的映射

到了这步,我们如何将文字提示中的文本条件信息注入到图像生成过程中?

回想一下,除了图像编码器,CLIP还学习了文本编码器。DALL-E 2使用了另一种模型,作者称之为先验模型,以便从图像标题的文本编码映射到对应图像的图像编码。DALL-E 2的作者用自回归模型和扩散模型进行了实验,但最终发现它们的性能相差无几。考虑到扩散模型的计算效率更高,因此选择扩散模型作为 DALL-E 2的先验。

从文本编码到相应图像编码的先验映射 修改自图源:https://arxiv.org/abs/2204.06125

先验训练

DALL-E 2中扩散先验的运行顺序是:

  1. 标记化的文本;
  2. 这些标记的CLIP文本编码;
  3. 扩散时间步的编码;
  4. 噪声图像通过CLIP图像编码器;
  5. Transformer输出的最终编码用于预测无噪声CLIP图像编码。

第四步 – 万事俱备

现在,我们已经拥有了DALL-E 2的所有“零件”,万事俱备,只需要将它们组合在一起就可以获得我们想要的结果——生成与文本指示相对应的图像:

  1. 首先,CLIP文本编码器将图像描述映射到表示空间;
  2. 然后扩散先验从CLIP文本编码映射到相应的CLIP图像编码;
  3. 最后,修改版的GLIDE生成模型通过反向扩散从表示空间映射到图像空间,生成众多可能图像中的一个。
DALL-E 2图像生成流程的高级概述 修改自图源:https://arxiv.org/abs/2204.06125

以上就是DALL-E 2的工作原理啦~

希望大家能注意到DALL-E 2开发的3个关键要点

  • DALL-E 2体现了扩散模型在深度学习中的能力,DALL-E 2中的先验子模型和图像生成子模型都是基于扩散模型的。虽然扩散模型只是在过去几年才流行起来,但其已经证明了自己的价值,我们可以期待在未来的各种研究中看到更多的扩散模型~
  • 第二点是我们应看到使用自然语言作为一种手段来训练最先进的深度学习模型的必要性与强大力量。DALL-E 2的强劲功能究其根本还是来自于互联网上提供的绝对海量的自然语言&图像数据对。使用这些数据不仅消除了人工标记数据集这一费力的过程所带来的发展瓶颈;这些数据的嘈杂、未经整理的性质也更加反映出深度学习模型必须对真实世界的数据具有鲁棒性。
  • 最后,DALL-E 2重申了Transformer作为基于网络规模数据集训练的模型中的最高地位,因为Transformer的并行性令人印象十分深刻。

ELECTRA: 超越BERT, 19年最佳NLP预训练模型

ELECTRA的全称是Efficiently Learning an Encoder that Classifies Token Replacements Accurately

Github: https://github.com/ymcui/Chinese-ELECTRA

ELECTRA : https://arxiv.org/abs/2003.10555

右边的图是左边的放大版,纵轴是GLUE分数,横轴是FLOPs (floating point operations),Tensorflow中提供的浮点数计算量统计。从上图可以看到,同等量级的ELECTRA是一直碾压BERT的,而且在训练更长的步数之后,达到了当时的SOTA模型——RoBERTa的效果。从左图曲线上也可以看到,ELECTRA效果还有继续上升的空间

2. 模型结构

NLP式的Generator-Discriminator

ELECTRA最主要的贡献是提出了新的预训练任务和框架,把生成式的Masked language model(MLM)预训练任务改成了判别式的Replaced token detection(RTD)任务,判断当前token是否被语言模型替换过。那么问题来了,我随机替换一些输入中的字词,再让BERT去预测是否替换过可以吗?可以的,因为我就这么做过,但效果并不好,因为随机替换太简单了

那怎样使任务复杂化呢?。。。咦,咱们不是有预训练一个MLM模型吗?

于是作者就干脆使用一个MLM的G-BERT来对输入句子进行更改,然后丢给D-BERT去判断哪个字被改过,如下:

于是,我们NLPer终于成功地把CV的GAN拿过来了!

Replaced Token Detection

但上述结构有个问题,输入句子经过生成器,输出改写过的句子,因为句子的字词是离散的,所以梯度在这里就断了,判别器的梯度无法传给生成器,于是生成器的训练目标还是MLM(作者在后文也验证了这种方法更好),判别器的目标是序列标注(判断每个token是真是假),两者同时训练,但判别器的梯度不会传给生成器,目标函数如下:

因为判别器的任务相对来说容易些,RTD loss相对MLM loss会很小,因此加上一个系数,作者训练时使用了50。

另外要注意的一点是,在优化判别器时计算了所有token上的loss,而以往计算BERT的MLM loss时会忽略没被mask的token。作者在后来的实验中也验证了在所有token上进行loss计算会提升效率和效果。

事实上,ELECTRA使用的Generator-Discriminator架构与GAN还是有不少差别,作者列出了如下几点:

3. 实验及结论

创新总是不易的,有了上述思想之后,可以看到作者进行了大量的实验,来验证模型结构、参数、训练方式的效果。

Weight Sharing

生成器和判别器的权重共享是否可以提升效果呢?作者设置了相同大小的生成器和判别器,在不共享权重下的效果是83.6,只共享token embedding层的效果是84.3,共享所有权重的效果是84.4。作者认为生成器对embedding有更好的学习能力,因为在计算MLM时,softmax是建立在所有vocab上的,之后反向传播时会更新所有embedding,而判别器只会更新输入的token embedding。最后作者只使用了embedding sharing。

Smaller Generators

从权重共享的实验中看到,生成器和判别器只需要共享embedding的权重就足矣了,那这样的话是否可以缩小生成器的尺寸进行训练效率提升呢?作者在保持原有hidden size的设置下减少了层数,得到了下图所示的关系图:

可以看到,生成器的大小在判别器的1/4到1/2之间效果是最好的。作者认为原因是过强的生成器会增大判别器的难度(判别器:小一点吧,我太难了)。

Training Algorithms

实际上除了MLM loss,作者也尝试了另外两种训练策略:

  1. Adversarial Contrastive Estimation:ELECTRA因为上述一些问题无法使用GAN,但也可以以一种对抗学习的思想来训练。作者将生成器的目标函数由最小化MLM loss换成了最大化判别器在被替换token上的RTD loss。但还有一个问题,就是新的生成器loss无法用梯度上升更新生成器,于是作者用强化学习Policy Gradient的思想,最终优化下来生成器在MLM任务上可以达到54%的准确率,而之前MLE优化下可以达到65%。(感谢 @阿雪我要 勘误)
  2. Two-stage training:即先训练生成器,然后freeze掉,用生成器的权重初始化判别器,再接着训练相同步数的判别器。

对比三种训练策略,得到下图:

可见“隔离式”的训练策略效果还是最好的,而两段式的训练虽然弱一些,作者猜测是生成器太强了导致判别任务难度增大,但最终效果也比BERT本身要强,进一步证明了判别式预训练的效果。

Small model? Big model?

这两节真是吊打之前的模型,作者重申了他的主要目的是提升预训练效率,于是做了GPU单卡就可以愉快训练的ELECTRA-Small和BERT-Small,接着和尺寸不变的ELMo、GPT等进行对比,结果如下:

数据简直优秀,仅用14M参数量,以前13%的体积,在提升了训练速度的同时还提升了效果,这里我疯狂点赞。

小ELECTRA的本事我们见过了,那大ELECTRA行吗?直接上图:

上面是各个模型在GLUE dev/text上的表现,可以看到ELECTRA仅用了1/4的计算量就达到了RoBERTa的效果。而且作者使用的是XLNet的语料,大约是126G,但RoBERTa用了160G。由于时间和精力问题,作者们没有把ELECTRA训练更久(应该会有提升),也没有使用各种榜单Trick,所以真正的GLUE test上表现一般(现在的T5是89.7,RoBERTa是88.5,没看到ELECTRA)。

Efficiency Analysis

前文中提到了,BERT的loss只计算被替换的15%个token,而ELECTRA是全部都计算的,所以作者又做了几个实验,探究哪种方式更好一些:

  1. ELECTRA 15%:让判别器只计算15% token上的损失
  2. Replace MLM:训练BERT MLM,输入不用[MASK]进行替换,而是其他生成器。这样可以消除这种pretrain-finetune直接的diff。
  3. All-Tokens MLM:接着用Replace MLM,只不过BERT的目标函数变为预测所有的token,比较接近ELECTRA。

三种实验结果如下:

可以看到:

  1. 对比ELECTRA和ELECTRA 15%:在所有token上计算loss确实能提升效果
  2. 对比Replace MLM和BERT:[MASK]标志确实会对BERT产生影响,而且BERT目前还有一个trick,就是被替换的10%情况下使用原token或其他token,如果没有这个trick估计效果会差一些。
  3. 对比All-Tokens MLM和BERT:如果BERT预测所有token 的话,效果会接近ELECTRA

另外,作者还发现,ELECTRA体积越小,相比于BERT就提升的越明显,说明fully trained的ELECTRA效果会更好。另外作者推断,由于ELECTRA是判别式任务,不用对整个数据分布建模,所以更parameter-efficient

4. 总结

无意中发现了这篇还在ICLR盲审的ELECTRA,读完摘要就觉得发现了新大陆,主要是自己也试过Replaced Token Detection这个任务,因为平时任务效果的分析和不久前看的一篇文章,让我深刻感受到了BERT虽然对上下文有很强的编码能力,却缺乏细粒度语义的表示,我用一张图表示大家就明白了:

这是把token编码降维后的效果,可以看到sky和sea明明是天与海的区别,却因为上下文一样而得到了极为相似的编码。细粒度表示能力的缺失会对真实任务造成很大影响,如果被针对性攻击的话更是无力,所以当时就想办法加上更细粒度的任务让BERT去区分每个token,不过同句内随机替换的效果并不好。相信这个任务很多人都想到过,不过都没有探索这么深入,这也告诫我们,idea遍地都是,往下挖才能有SOTA。

ELECTRA是BERT推出这一年来我见过最赞的idea,它不仅提出了能打败MLM的预训练任务,更推出了一种十分适用于NLP的类GAN框架。毕竟GAN太牛逼了,看到deepfake的时候我就想,什么时候我们也能deepcheat,但听说GAN在NLP上的效果一直不太好,这次ELECTRA虽然只用了判别器,但个人认为也在一定程度上打开了潘多拉魔盒。

算法工程师面试知识点整理:

https://mp.weixin.qq.com/s/nPVbgOBOPs5VjW6_U-Om3w