Swin-Unet:Unet形状的纯Transformer的医学图像分割

首个基于纯Transformer的U-Net形的医学图像分割网络,其中利用Swin Transformer构建编码器、bottleneck和解码器,表现SOTA!性能优于TransUnet、Att-UNet等

单位:慕尼黑工业大学, 复旦大学, 华为(田奇等人)
代码:https://github.com/HuCaoFighting/Swin-Unet
论文下载链接:https://arxiv.org/abs/2105.0553

在过去的几年中,卷积神经网络(CNN)在医学图像分析中取得了里程碑式的进展。尤其是,基于U形架构和跳跃连接的深度神经网络已广泛应用于各种医学图像任务中。但是,尽管CNN取得了出色的性能,但是由于卷积操作的局限性,它无法很好地学习全局和远程语义信息交互。

在本文中,我们提出了Swin-Unet,它是用于医学图像分割的类似Unet的纯Transformer。标记化的图像块通过跳跃连接被馈送到基于Transformer的U形En-Decoder架构中,以进行局部全局语义特征学习。

具体来说,我们使用带有偏移窗口的分层Swin Transformer作为编码器来提取上下文特征。

实验结果

在对输入和输出进行4倍的直接下采样和上采样的情况下,对多器官和心脏分割任务进行的实验表明,基于纯Transformer的U形编码器/解码器网络优于那些全卷积或者Transformer和卷积的组合。

个人感觉,这个transformer架构下的参数量应该会大很多吧(近百M)?目前作者好像并未在论文中给出对比。另外就是这个 感觉没啥创新点..感觉没啥有点水。

UNet网络

论文: https://arxiv.org/abs/1505.04597

FCN虽然做出了开创性的工作,FCN-8s相较于此前的SOTA分割表现,已经取得了巨大的优势。但从分割效果上看还很粗糙,对图像的细节处理还很不成熟,也没有考虑到像素与像素之间的上下文(context)关系,所以FCN更像是一项抛砖引玉式的工作,随着U形的编解码结构成为通用的语义分割网络设计范式,各种网络如雨后春笋般涌现。UNet是U形网络结构最经典和最主要的代表网络,因其网络结构是一个U形而得名,这类编解码的结构也因而被称之为U形结构。提出UNet的论文为U-Net: Convolutional Networks for Biomedical Image Segmentation,与FCN提出时间相差了两个月,其结构设计在FCN基础上做了进一步的改进,设计初衷主要是用于医学图像的分割。截至到本书写稿,UNet在谷歌学术上的引用次数已达44772次,堪称深度学习语义分割领域的里程碑式的工作。

1、与FCN区别

U-Net和FCN非常的相似,U-Net比FCN稍晚提出来,但都发表在2015年,和FCN相比,U-Net的第一个特点是完全对称,也就是左边和右边是很类似的,而FCN的decoder相对简单,只用了一个deconvolution的操作,之后并没有跟上卷积结构。第二个区别就是skip connection,FCN用的是加操作(summation),U-Net用的是叠操作(concatenation)。这些都是细节,重点是它们的结构用了一个比较经典的思路,也就是编码和解码(encoder-decoder),早在2006年就被Hinton大神提出来发表在了nature上.

当时这个结构提出的主要作用并不是分割,而是压缩图像和去噪声。输入是一幅图,经过下采样的编码,得到一串比原先图像更小的特征,相当于压缩,然后再经过一个解码,理想状况就是能还原到原来的图像。这样的话我们存一幅图的时候就只需要存一个特征和一个解码器即可。这个想法我个人认为是很漂亮了。同理,这个思路也可以用在原图像去噪,做法就是在训练的阶段在原图人为的加上噪声,然后放到这个编码解码器中,目标是可以还原得到原图。

后来把这个思路被用在了图像分割的问题上,也就是现在我们看到的U-Net结构,在它被提出的三年中,有很多很多的论文去讲如何改进U-Net或者FCN,不过这个分割网络的本质的拓扑结构是没有改动的。举例来说,去年ICCV上凯明大神提出的Mask RCNN. 相当于一个检测,分类,分割的集大成者,我们仔细去看它的分割部分,其实使用的也就是这个简单的FCN结构。说明了这种“U形”的编码解码结构确实非常的简洁,并且最关键的一点是好用。

2、为什么有效

相比于FCN和Deeplab等,UNet共进行了4次上采样,并在同一个stage使用了skip connection,而不是直接在高级语义特征上进行监督和loss反传,这样就保证了最后恢复出来的特征图融合了更多的low-level的feature,也使得不同scale的feature得到了的融合,从而可以进行多尺度预测和DeepSupervision。4次上采样也使得分割图恢复边缘等信息更加精细。

其次我们聊聊【医疗影像】,医疗影像有什么样的特点呢(尤其是相对于自然影像而言)?

1.图像语义较为简单、结构较为固定。我们做脑的,就用脑CT和脑MRI,做胸片的只用胸片CT,做眼底的只用眼底OCT,都是一个固定的器官的成像,而不是全身的。由于器官本身结构固定和语义信息没有特别丰富,所以高级语义信息和低级特征都显得很重要(UNet的skip connection和U型结构就派上了用场)。

2.数据量少。医学影像的数据获取相对难一些,很多比赛只提供不到100例数据。所以我们设计的模型不宜多大,参数过多,很容易导致过拟合。

原始UNet的参数量在28M左右(上采样带转置卷积的UNet参数量在31M左右),而如果把channel数成倍缩小,模型可以更小。缩小两倍后,UNet参数量在7.75M。缩小四倍,可以把模型参数量缩小至2M以内,非常轻量。个人尝试过使用Deeplab v3+和DRN等自然图像语义分割的SOTA网络在自己的项目上,发现效果和UNet差不多,但是参数量会大很多。

为什么适用于医学图像?

(1)因为医学图像边界模糊、梯度复杂,需要较多的高分辨率信息。高分辨率用于精准分割。

(2)人体内部结构相对固定,分割目标在人体图像中的分布很具有规律,语义简单明确,低分辨率信息能够提供这一信息,用于目标物体的识别。

UNet结合了低分辨率信息(提供物体类别识别依据)和高分辨率信息(提供精准分割定位依据),完美适用于医学图像分割。

网络结构

在医学图像领域,具体到更加细分的医学图像识别任务时,大量的带有高质量标注的图像数据十分难得,在此之前的通常做法是采用滑动窗口卷积(类似于图像分块)的方式来进行图像局部预测,这么做的好处是可以做图像像素做到一定程度定位,其次就是滑窗分块能够使得训练样本量增多。但缺点也很明显,一个是滑窗操作非常耗时,推理的时候效率低下,其次就是不能兼顾定位精度和像素上下文信息的利用率。UNet在FCN的基础上,完整地给出了U形的编解码结构,如下图所示

输入是一幅图,输出是目标的分割结果。继续简化就是,一幅图,编码,或者说降采样,然后解码,也就是升采样,然后输出一个分割结果。根据结果和真实分割的差异,反向传播来训练这个分割网络。我们可以说,U-Net里面最精彩的部分就是这三部分:

  • 下采样
  • 上采样
  • skip connection

UNet结构包括编码器下采样、解码器上采样和同层跳跃连接三个组成部分。编码器由4组卷积、ReLU激活和最大池化构成,每一组均有两次3*3的卷积,每个卷积层后面都有一次ReLU激活函数,然后再进行一次步长为2的2*2最大池化进行下采样,如第一组操作输入图像大小为572*572,两轮3*3的卷积之后的特征图大小为568*568,再经过22最大池化后的输出尺寸为284*284。解码器由4组2*2转置卷积、3*3卷积构成和一个ReLU激活函数构成,在最后的输出层又补充了一个1*1卷积。最后是同层跳跃连接,这也是UNet的特色操作之一,指的是将下采样时每一层的输出裁剪后连接到同层的上采样层做融合。每一次下采样都会有一个跳跃连接与对应的上采样进行融合,这种不同尺度的特征融合对上采样恢复像素大有帮助,具体来说就是高层(浅层)下采样倍数小,特征图具备更加细致的图特征,低层(深层)下采样倍数大,信息经过大量浓缩,空间损失大,但有助于目标区域(分类)判断,当高层和低层的特征进行融合时,分割效果往往会非常好。从某种程度上讲,这种跳跃连接也可以视为一种深度监督。

我们将UNet结构按照编码器、解码器和同层跳跃连接进行简化,如下图所示。编码器下采样用于特征提取和语义信息浓缩,解码器上采样用于图像像素恢复,跳跃连接则用于信息补充。自此,基于U形结构的编解码设计成为深度学习语义分割中的奠基性的网络结构,经过近几年的发展,语义分割虽然取得了长足的进步,但UNet和编解码结构一直是新的模型设计的参照对象。

代码实现:

# 导入PyTorch相关模块
import torch
import torch.nn as nn
import torch.nn.functional as F

### 编码块
class UNetEnc(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=False):
        super().__init__()
    # 每一个编码块中的结构
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, dilation=2),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers += [nn.Dropout(.5)]
        layers += [nn.MaxPool2d(2, stride=2, ceil_mode=True)]
        self.down = nn.Sequential(*layers)
  # 编码块前向计算流程
    def forward(self, x):
        return self.down(x)

### 解码块    
class UNetDec(nn.Module):
    def __init__(self, in_channels, features, out_channels):
        super().__init__()
    # 每一个解码块中的结构
        self.up = nn.Sequential(
            nn.Conv2d(in_channels, features, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, 3),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(features, out_channels, 2, stride=2),
            nn.ReLU(inplace=True),
        )
  # 解码块前向计算流程
    def forward(self, x):
        return self.up(x)

### 基于编解码的U-Net
class UNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
    # 四个编码块
        self.enc1 = UNetEnc(3, 64)
        self.enc2 = UNetEnc(64, 128)
        self.enc3 = UNetEnc(128, 256)
        self.enc4 = UNetEnc(256, 512, dropout=True)
    # 中间部分(U形底部)
        self.center = nn.Sequential(
            nn.Conv2d(512, 1024, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.ConvTranspose2d(1024, 512, 2, stride=2),
            nn.ReLU(inplace=True),
        )
    # 四个解码块
        self.dec4 = UNetDec(1024, 512, 256)
        self.dec3 = UNetDec(512, 256, 128)
        self.dec2 = UNetDec(256, 128, 64)
        self.dec1 = nn.Sequential(
            nn.Conv2d(128, 64, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3),
            nn.ReLU(inplace=True),
        )
        self.final = nn.Conv2d(64, num_classes, 1)

    # 前向传播过程
    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        center = self.center(enc4)
        # 包含了同层分辨率级联的解码块
        dec4 = self.dec4(torch.cat([
            center, F.upsample_bilinear(enc4, center.size()[2:])], 1))
        dec3 = self.dec3(torch.cat([
            dec4, F.upsample_bilinear(enc3, dec4.size()[2:])], 1))
        dec2 = self.dec2(torch.cat([
            dec3, F.upsample_bilinear(enc2, dec3.size()[2:])], 1))
        dec1 = self.dec1(torch.cat([
            dec2, F.upsample_bilinear(enc1, dec2.size()[2:])], 1))
        return F.upsample_bilinear(self.final(dec1), x.size()[2:])

Unet论文合集(待更新)–医学图像

自2015年以来,UNET在医学图像细分中取得了重大突破,开放了深度学习时代。后来的研究人员在UNET的基础上做出了很多改进,以提高语义细分的性能。

摘自:https://github.com/ShawnBIT/UNet-family

如何查找代码:在 https://paperswithcode.com/ 查找论文即可

UNet-family

2015

  • U-Net: Convolutional Networks for Biomedical Image Segmentation (MICCAI) [paper] [my-pytorch][keras]

2016

  • V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation [paper] [caffe][pytorch]
  • 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation [paper][pytorch]

2017

  • H-DenseUNet: Hybrid Densely Connected UNet for Liver and Tumor Segmentation from CT Volumes (IEEE Transactions on Medical Imaging)[paper][keras]
  • GP-Unet: Lesion Detection from Weak Labels with a 3D Regression Network (MICCAI) [paper]

2018

  • UNet++: A Nested U-Net Architecture for Medical Image Segmentation (MICCAI) [paper][my-pytorch][keras]
  • MDU-Net: Multi-scale Densely Connected U-Net for biomedical image segmentation [paper]
  • DUNet: A deformable network for retinal vessel segmentation [paper]
  • RA-UNet: A hybrid deep attention-aware network to extract liver and tumor in CT scans [paper]
  • Dense Multi-path U-Net for Ischemic Stroke Lesion Segmentation in Multiple Image Modalities [paper]
  • Stacked Dense U-Nets with Dual Transformers for Robust Face Alignment [paper]
  • Prostate Segmentation using 2D Bridged U-net [paper]
  • nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation [paper][pytorch]
  • SUNet: a deep learning architecture for acute stroke lesion segmentation and outcome prediction in multimodal MRI [paper]
  • IVD-Net: Intervertebral disc localization and segmentation in MRI with a multi-modal UNet [paper]
  • LADDERNET: Multi-Path Networks Based on U-Net for Medical Image Segmentation [paper][pytorch]
  • Glioma Segmentation with Cascaded Unet [paper]
  • Attention U-Net: Learning Where to Look for the Pancreas [paper]
  • Recurrent Residual Convolutional Neural Network based on U-Net (R2U-Net) for Medical Image Segmentation [paper]
  • Concurrent Spatial and Channel ‘Squeeze & Excitation’ in Fully Convolutional Networks [paper]
  • A Probabilistic U-Net for Segmentation of Ambiguous Images (NIPS) [paper] [tensorflow]
  • AnatomyNet: Deep Learning for Fast and Fully Automated Whole-volume Segmentation of Head and Neck Anatomy [paper]
  • 3D RoI-aware U-Net for Accurate and Efficient Colorectal Cancer Segmentation [paper][pytorch]
  • Detection and Delineation of Acute Cerebral Infarct on DWI Using Weakly Supervised Machine Learning (Y-Net) (MICCAI) [paper](Page 82)
  • Fully Dense UNet for 2D Sparse Photoacoustic Tomography Artifact Removal [paper]

2019

  • MultiResUNet : Rethinking the U-Net Architecture for Multimodal Biomedical Image Segmentation [paper][keras]
  • U-NetPlus: A Modified Encoder-Decoder U-Net Architecture for Semantic and Instance Segmentation of Surgical Instrument [paper]
  • Probability Map Guided Bi-directional Recurrent UNet for Pancreas Segmentation [paper]
  • CE-Net: Context Encoder Network for 2D Medical Image Segmentation [paper][pytorch]
  • Graph U-Net [paper]
  • A Novel Focal Tversky Loss Function with Improved Attention U-Net for Lesion Segmentation (ISBI) [paper]
  • ST-UNet: A Spatio-Temporal U-Network for Graph-structured Time Series Modeling [paper]
  • Connection Sensitive Attention U-NET for Accurate Retinal Vessel Segmentation [paper]
  • CIA-Net: Robust Nuclei Instance Segmentation with Contour-aware Information Aggregation [paper]
  • W-Net: Reinforced U-Net for Density Map Estimation [paper]
  • Automated Segmentation of Pulmonary Lobes using Coordination-guided Deep Neural Networks (ISBI oral) [paper]
  • U2-Net: A Bayesian U-Net Model with Epistemic Uncertainty Feedback for Photoreceptor Layer Segmentation in Pathological OCT Scans [paper]
  • ScleraSegNet: an Improved U-Net Model with Attention for Accurate Sclera Segmentation (ICB Honorable Mention Paper Award) [paper]
  • AHCNet: An Application of Attention Mechanism and Hybrid Connection for Liver Tumor Segmentation in CT Volumes [paper]
  • A Hierarchical Probabilistic U-Net for Modeling Multi-Scale Ambiguities [paper]
  • Recurrent U-Net for Resource-Constrained Segmentation [paper]
  • MFP-Unet: A Novel Deep Learning Based Approach for Left Ventricle Segmentation in Echocardiography [paper]
  • A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation (MICCAI 2019) [paper][pytorch]
  • ResUNet-a: a deep learning framework for semantic segmentation of remotely sensed data [paper]
  • A multi-task U-net for segmentation with lazy labels [paper]
  • RAUNet: Residual Attention U-Net for Semantic Segmentation of Cataract Surgical Instruments [paper]
  • 3D U2-Net: A 3D Universal U-Net for Multi-Domain Medical Image Segmentation (MICCAI 2019) [paper] [pytorch]
  • SegNAS3D: Network Architecture Search with Derivative-Free Global Optimization for 3D Image Segmentation (MICCAI 2019) [paper]
  • 3D Dilated Multi-Fiber Network for Real-time Brain Tumor Segmentation in MRI [paper][pytorch] (MICCAI 2019)
  • The Domain Shift Problem of Medical Image Segmentation and Vendor-Adaptation by Unet-GAN [paper]
  • Recurrent U-Net for Resource-Constrained Segmentation [paper] (ICCV 2019)
  • Siamese U-Net with Healthy Template for Accurate Segmentation of Intracranial Hemorrhage (MICCAI 2019)

2020

  • U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection (Pattern Recognition 2020) [paper][pytorch]
  • UNET 3+: A Full-Scale Connected UNet for Medical Image Segmentation (ICASSP 2020) [paper][pytorch]

2021

  • TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation [paper][pytorch]
  • Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation [paper][pytorch]
  • UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer [paper][pytorch]

FCN全卷积网络–图像分割的开山之作

论文地址: https://arxiv.org/abs/1411.4038

随着CNN在图像识别中取得巨大成功,一些经典的图像分类网络(AlexNet、VGG、GoogLeNet、ResNet)也逐渐被应用于更加细分的视觉任务中。很多研究者也在探索如何将分类网络进行改造后用于语义分割的密集预测问题(dense predictions)。在更高效的语义分割网络提出之前,学术界用于密集预测任务的模型主要有以下几个特点:

(1)小模型。早期的网络结构受限于数据量和高性能的计算资源,在设计上一般不会使用过大的模型。

(2)分块训练。分块训练(patchwise training)在当时是图像训练的普遍做法,但该方法对于全卷积网络的训练会显得相对低效,但分块训练的优点在于能够规避类别不均衡问题,并且能够缓解密集分块的空间相关性问题。

(3)输入移位与输出交错。该方法可以视为一种输入与输出的变换方法,在OverFeat等结构中被广泛使用。

(4)后处理。对于神经网络输出质量不高的问题,对输出加后处理也是常见做法,常用的后处理方法包括超像素投影(superpixel projection)、随机场正则化(random field regularization)和图像滤波处理等。

可以看到,早期用于目标检测、关键点预测和语义分割等密集预测问题整体来看有两个缺陷,一是无法实现端到端(end-to-end)的流程,模型整体效率不佳;第二个则是不能做到真正的密集预测的特征:像素到像素(pixels-to-pixels)的预测。

全卷积网络(Fully Convolutional Networks, FCN)的提出,正好可以解决早期网络结构普遍存在的上述两个缺陷。FCN在2015年的一篇论文Fully Convolutional Networks for Semantic Segmentation中提出,其主要思路在于用卷积层代替此前分类网络中的全连接层,将全连接层的语义标签输出改为卷积层的语义热图(heatmap)输出,再结合上采样技术实现像素到像素的密集预测。如下图所示,上图为常见分类网络的流程,在五层卷积网络之后有三层全连接网络,最后输出一个包含类别语义信息的输出概率;下图为FCN网络流程,在上图分类网络的基础上,将最后三层全连接层改为卷积层,输出也相应的变为分类预测的热图,这样就为了最后的像素级的密集预测提供了基础。

所以,FCN实现密集预测的关键在于修改全连接层为卷积层,那么具体是如何修改的呢?先来详细分析一下的卷积层和全连接层的特征。卷积层与全连接层最大的区别在于卷积层每次计算时只与输入图像中一个具体的局部做运算,但二者都是做点积计算,其函数形式是类似的。假设给定在指定网络层任意坐标点(i,j)的数据向量Xij,而下一层对应坐标点的数据向量为Yij,有:

其中为卷积核大小或者权重向量长度,s为步长(stride),而f_ks则表示当前层到下一层的映射函数,f_ks既可以表示为卷积层又可以表示为全连接层,所以二者之间的转换是有理论基础的。

FCN分别在AlexNet、VGG和GoogLeNet上进行了全连接层转卷积层的修改,通过实验发现以VGG16作为主干网络效果最好,完整的FCN结构如下图所示,第一行最左边为原始输入图像,图像尺寸为32×32,conv为卷积层,pool为池化层,可以注意到conv6-7是最后的卷积层,此时得到的密集预测热图尺寸为输入图像的1/32,为了实现像素到像素的预测,还需要对热图进行上采样,FCN采用双线性插值(bilinear interpolation)进行上采样,所以这里需要将热图上采样32倍来恢复到原始图像的尺寸,因而第一行的网络结构也叫FCN-32s。直接进行32倍上采样得到的输出无疑是较为粗糙的,为了提高像素预测质量,FCN又分别有FCN-16s和FCN-8s的改进版本。图中第二行即为FCN-16s,主要区别在于先将conv7(1×1)的输出热图进行2倍上采样,然后将其与pool4(2×2)进行融合,最后对融合后的结果进行16倍上采样得到最终预测结果,同理FCN-8s将pool3(4×4)、2倍上采样后的pool4(4×4)以及4倍上采样的conv7(4×4)进行融合,最后再进行8倍的上采样得到语义分割图像。

所以,从FCN-32s到FCN-8s,其实一种粗分割到精细分割的演变过程,FCN通过融合浅层图像特征和深层卷积热图的方式来得到当时的SOTA(State of the art)水平的语义分割模型。下图是FCN-32s、FCN-16s和FCN-8s在同一张图像上的分割效果,与分割的标准图像(Ground truth)相比,可以看到三个模型的分割精度是在不断优化的。

下方代码给出FCN-8s的一个PyTorch简略实现方式,便于读者加深对FCN的理解。代码中对于卷积下采样使用了VGG16的预训练权重,分别构建了四个特征提取模块、一个卷积块和三个独立的卷积层。在前向传播流程中,将conv7、pool3和pool4进行融合,最后再做8倍的双线性插值上采样。

# 导入PyTorch相关模块
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

### 定义FCN-8s模型类
class FCN8(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # 提取VGG16预训练权重作为特征
        feats =list(models.vgg16(pretrained=True).features.children())
        # 取前9层为第一特征模块
        self.feat1 = nn.Sequential(*feats[0:9])
        # 取第10-15层为第二特征模块
        self.feat2 = nn.Sequential(*feats[10:16])
        # 取第16-22层为第三特征模块
        self.feat3 = nn.Sequential(*feats[17:23])
        # 取后6层为第四特征模块
        self.feat4 = nn.Sequential(*feats[24:30])
        # 卷积层权重不参与训练更新
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.requires_grad = False
        # 定义卷积块
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(512, 4096, 7),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Conv2d(4096, 4096, 1),
            nn.ReLU(inplace=True),
            nn.Dropout(),
        )
        # 改最后三层的全连接层为卷积层
        self.conv1 = nn.Conv2d(256, num_classes, 1)
        self.conv2 = nn.Conv2d(512, num_classes, 1)
        self.conv3 = nn.Conv2d(4096, num_classes, 1)

    ### 定义前向计算流程
    def forward(self, x):
        feat1 = self.feat1(x)
        feat2 = self.feat2(feat1)
        feat3 = self.feat3(feat2)
        feat4 = self.feat4(feat3)
        conv_blocks = self.conv_blocks(feat4)

        conv1 = self.conv1(feat2)
        conv2 = self.conv2(feat3)
        conv3 = self.conv3(conv_blocks)      
        outputs = F.upsample_bilinear(conv_blocks, conv2.size()[2:])
        # 第一次融合
        outputs += conv2
        outputs = F.upsample_bilinear(outputs, conv1.size()[2:])
        # 第二次融合
        outputs += conv1
        return F.upsample_bilinear(outputs, x.size()[2:]) 

FCN是深度学习语义分割网络的开山之作,在结构设计上率先将全卷积网络用于深度学习语义分割任务,在经典分类网络的基础上实现了像素到像素和端到端的分割。FCN整体上已具备编解码架构的U形网络雏形,为后续的网络设计开创了坚实的基础。

SUNet: Swin Transformer with UNet for Image Denoising

ISCAS 2022的一篇文章,作为首个Swin Transformer在图像去噪领域的应用,效果来说感觉还有很大提高空间。但不的不说,自从Swin Transformer(2021)提出后,在整个cv领域独领风骚。作为一个通用的架构,可以将其应用在各个cv领域,从paperwithcode里就可以见其影响力:(截止到22.8.28)

1、目标检测:

2、图像超分辨率

3、实例分割:

4、3D医学图像分割:

今天,就来看看Swin Transformer 对于图像去噪任务的处理效果:

个人觉得 Swin Transformer 对于去噪来说还有很大的扩展空间,这篇论文的模型效果不是很好,可以值得去尝试尝试,看看有没有更好的方法提高模型效果。

论文的主要贡献:

1、结合Unet网络+ Swin Transformer

2、提出了一个双上采样模块 dual up-sample block

3、首个将Swin +unet用于图像去噪领域

4、在 两个通用数据集中测试的结果还不错

网络结构:

网络分为三个部分:1)Shallow feature extraction; 2) UNet feature extraction; and
3) Reconstruction module

1、Shallow feature extraction

使用3*3卷积,提取特征,输出通道96

2、 UNet feature extraction

带有 Swin Transformer Block 的UNET体系结构,其中包含8个 Swin Transformer 层,以取代卷积。
Swin Transformer Block(STB)和Swin Transformer层(STL):

STB:包含8个STL

这块建议去看 Swin Transformer 论文,讲的比较清楚。注意此时的输入输出大小完全一致,因此需要下采样。

下采样: Patch merging

Patch merging:通过查看Patch merging的源码,可以看到,其实就是一个下采样的过程,它可以看成一种加权池化的过程。实现维度下采样、特征加倍的效果。

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

上采样:Dual up-sample

作者提出了 上采样,

该模块包括两种现有的上样本方法(即双线性和PixelShuffle),以防止棋盘伪影(Deconvolution and Checkerboard Artifacts中提出的)https://distill.pub/2016/deconv-checkerboard/ 产生原因:主要是出现在反卷积中。

上采样模块

通过两种上采样后,cat维度拼接后,通过一个卷积层将维度减半C/2

实验:

如上图所示。

中文文本清洗与特征提取

摘自知乎:

bookname嵌入式AI算法研究

中文文本清洗

中文文本清洗:

– 去除指定无用的符号

– 让文本只保留汉字

– 文本中的表情符号去除

– 繁体中文与简体中文转换

中文文本清洗类

import re
from opencc import OpenCC
from bs4 import BeautifulSoup
import jieba
from glob import glob

import torch
from tqdm.auto import tqdm

import sys
!ls ../package/
sys.path.insert(0, "../package/")
from ltp import LTP
nlp = LTP(path="base")

class TextCleaner:
    '''
        批量清洗数据
    '''
    def __init__(self,
                 remove_space=True, # 去除空格
                 remove_suspension=True, # 转换省略号
                 only_zh=False, # 只保留汉子
                 remove_sentiment_character=True, # 去除表情符号
                 to_simple=True, # 转化为简体中文
                 remove_html_label=True,
                 remove_stop_words=False,
                 stop_words_dir="./停用词/",
                 with_space=False,
                 batch_size=256):
        self._remove_space = remove_space
        self._remove_suspension = remove_suspension
        self._remove_sentiment_character = remove_sentiment_character

        self._only_zh = only_zh
        self._to_simple = to_simple

        self._remove_html_label = remove_html_label
        self._remove_stop_words = remove_stop_words
        self._stop_words_dir = stop_words_dir

        self._with_space = with_space
        self._batch_size = batch_size

    def clean_single_text(self, text):
        if self._remove_space:
            text = self.remove_space(text)
        if self._remove_suspension:
            text = self.remove_suspension(text)
        if self._remove_sentiment_character:
            text = self.remove_sentiment_character(text)
        if self._to_simple:
            text = self.to_simple(text)
        if self._only_zh:
            text = self.get_zh_only(text)
        if self._remove_html_label:
            text = self.remove_html(text)
        return text

    def clean_text(self, text_list):
        text_list = [self.clean_single_text(text) for text in tqdm(text_list)]
        tokenized_words_list = self.tokenizer_batch_text(text_list)
        if self._remove_stop_words:
            text_list = [self.remove_stop_words(words_list, self._stop_words_dir, self._with_space) for words_list in tokenized_words_list]
        return text_list

    def remove_space(self, text):     #定义函数
        return text.replace(' ','')   # 去掉文本中的空格

    def remove_suspension(self, text):
        return text.replace('...', '。')

    def get_zh_only(self, text):
        def is_chinese(uchar):
            if uchar >= u'\u4e00' and uchar <= u'\u9fa5':  # 判断一个uchar是否是汉字 中文字符的编码范围 \u4e00 - \u9fff,只要在这个范围就可以
                return True
            else:
                return False

        content = ''
        for i in text:
            if is_chinese(i):
                content = content+i
        return content

    def remove_sentiment_character(self, sentence):    
        pattern = re.compile("[^\u4e00-\u9fa5^,^.^!^,^。^?^?^!^a-z^A-Z^0-9]")  #只保留中英文、数字和符号,去掉其他东西
        #若只保留中英文和数字,则替换为[^\u4e00-\u9fa5^a-z^A-Z^0-9]
        line = re.sub(pattern,'',sentence)  #把文本中匹配到的字符替换成空字符
        new_sentence=''.join(line.split())    #去除空白
        return new_sentence

    def to_simple(self, sentence):
        new_sentence = OpenCC('t2s').convert(sentence)   # 繁体转为简体
        return new_sentence

    def to_tradition(self, sentence):
        new_sentence = OpenCC('s2t').convert(sentence)   # 简体转为繁体
        return new_sentence

    def remove_html(self, text):
        return BeautifulSoup(text, 'html.parser').get_text() #去掉html标签

    def tokenizer_batch_text(self, text_list):
        tokenized_text = []
        len_text = len(text_list)
        with torch.no_grad():
            steps = self._batch_size
            for start_idx in tqdm(range(0, len_text, steps)):
                if start_idx + steps > len_text:
                    tokenized_text += nlp.seg(text_list[start_idx:])[0]
                else:
                    tokenized_text += nlp.seg(text_list[start_idx:start_idx+steps])[0]
        return tokenized_text

    def remove_stop_words(self, words_list, stop_words_dir, with_space=False):
        """
        中文数据清洗  stopwords_chineses.txt存放在博客园文件中
        :param text:
        :return:
        """
        stop_word_filepath_list = glob(stop_words_dir + "/*.txt")
        for stop_word_filepath in stop_word_filepath_list:
            with open(stop_word_filepath) as fp:
                stopwords = {}.fromkeys([line.rstrip() for line in fp]) #加载停用词(中文)
        eng_stopwords = set(stopwords) #去掉重复的词
        words = [w for w in words_list if w not in eng_stopwords] #去除文本中的停用词
        if with_space:
            return ' '.join(words)
        else:
            return ''.join(words)
ltp


file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
cleaner = TextCleaner(remove_stop_words=True, with_space=True)
contents = ['   大家好, 欢迎一起来学习文本的空格   去除   !', '   大家好,文本的空格   去除   !']
results = cleaner.clean_text(contents)
print(results)
0%|          | 0/2 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]


['好 , 学习 文本 空格 去除 !', '好 , 文本 空格 去除 !']

去除空格

# 去除空格
contents = '   大家好, 欢迎一起来学习文本的空格   去除   !'
print('处理前文本:'+contents)
def process(our_data):     #定义函数
    content = our_data.replace(' ','')   # 去掉文本中的空格
    print('处理后文本:'+content)
process(contents)
处理前文本:   大家好, 欢迎一起来学习文本的空格   去除   !
处理后文本:大家好,欢迎一起来学习文本的空格去除!

去除空格的同时把省略号转换为句号

# 去除空格的同时把省略号转换为句号
contents = '   大家好, 这里还有  很多的知识...一起拉学习吧 !'
print('处理前文本:'+contents)
def process(data):     #定义函数
    content1 = data.replace(' ','')    # 去掉文本中的空格
    content2 = content1.replace('...','。')    # 去掉文本中的空格
    print('处理后文本:'+ content2)
process(contents)
处理前文本:   大家好, 这里还有  很多的知识...一起拉学习吧 !
处理后文本:大家好,这里还有很多的知识。一起拉学习吧!

让文本只保留汉字

def is_chinese(uchar):
    if uchar >= u'\u4e00' and uchar <= u'\u9fa5':  # 判断一个uchar是否是汉字
        return True
    else:
        return False

def allcontents(contents):
    content = ''
    for i in contents:
        if is_chinese(i):
            content = content+i
    print('\n处理后的句子为:\n'+content)

centents = '1,2,3...我们开始吧, 加油!'
print('原句子为:\n'+centents)
allcontents(centents)
原句子为:
1,2,3...我们开始吧, 加油!

处理后的句子为:
我们开始吧加油

文本中的表情符号去除

import re
sentence='现在听着音乐,duo rui mi,很开心*_*'
print('原句子为:\n'+sentence)

def clear_character(sentence):    
    pattern = re.compile("[^\u4e00-\u9fa5^,^.^!^a-z^A-Z^0-9]")  #只保留中英文、数字和符号,去掉其他东西
    #若只保留中英文和数字,则替换为[^\u4e00-\u9fa5^a-z^A-Z^0-9]
    line=re.sub(pattern,'',sentence)  #把文本中匹配到的字符替换成空字符
    new_sentence=''.join(line.split())    #去除空白
    print('\n处理后的句子为:\n'+new_sentence) 

clear_character(sentence)
原句子为:
现在听着音乐,duo rui mi,很开心*_*

处理后的句子为:
现在听着音乐,duoruimi,很开心

繁体中文与简体中文转换

from opencc import OpenCC

sentence = '你现在读的这里是简体,这里是繁体,能看懂吗?'
print('原句子为:\n'+sentence)

def Simplified(sentence):
    new_sentence = OpenCC('t2s').convert(sentence)   # 繁体转为简体
    print('\n处理后的句子为:\n'+new_sentence)

def Traditional(sentence):
    new_sentence = OpenCC('s2t').convert(sentence)   # 简体转为繁体
    print('\n处理后的句子为:\n'+new_sentence) 

Simplified(sentence)
Traditional(sentence)
原句子为:
你现在读的这里是简体,这里是繁体,能看懂吗?

处理后的句子为:
你现在读的这里是简体,这里是繁体,能看懂吗?

处理后的句子为:
你现在读的这里是简体,这里是繁体,能看懂吗?

OpenCC的参数设置:

- hk2s: Traditional Chinese (Hong Kong standard) to Simplified Chinese
- s2hk: Simplified Chinese to Traditional Chinese (Hong Kong standard)
- s2t: Simplified Chinese to Traditional Chinese
- s2tw: Simplified Chinese to Traditional Chinese (Taiwan standard)
- s2twp: Simplified Chinese to Traditional Chinese (Taiwan standard, with phrases)
- t2hk: Traditional Chinese to Traditional Chinese (Hong Kong standard)
- t2s: Traditional Chinese to Simplified Chinese
- t2tw: Traditional Chinese to Traditional Chinese (Taiwan standard)
- tw2s: Traditional Chinese (Taiwan standard) to Simplified Chinese
- tw2sp: Traditional Chinese (Taiwan standard) to Simplified Chinese (with phrases)

去除html标签和停用词

from bs4 import BeautifulSoup
import jieba
from glob import glob

def clean_chineses_text(text, with_space=False):
    """
    中文数据清洗  stopwords_chineses.txt存放在博客园文件中
    :param text:
    :return:
    """
    text = BeautifulSoup(text, 'html.parser').get_text() #去掉html标签
    text = jieba.lcut(text)
    stop_word_filepath_list = glob("./停用词/*.txt")
#     print(stop_word_filepath_list)
    for stop_word_filepath in stop_word_filepath_list:
        with open(stop_word_filepath) as fp:
            stopwords = {}.fromkeys([line.rstrip() for line in fp]) #加载停用词(中文)
    eng_stopwords = set(stopwords) #去掉重复的词
    words = [w for w in text if w not in eng_stopwords] #去除文本中的停用词
    if with_space:
        return ' '.join(words)
    else:
        return ''.join(words)
clean_chineses_text("你现在读的这里是简体,这里是繁体,能看懂吗?", with_space=True)
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.703 seconds.
Prefix dict has been built successfully.





'读 简体 , 这里 繁体 , 能看懂 吗 ?'
ENGLISH_STOP_WORDS = frozenset([
    "about", "above", "across", "after", "afterwards", "again", "against",
    "all", "almost", "alone", "along", "already", "also", "although", "always",
    "am", "among", "amongst", "amoungst", "amount", "an", "and", "another",
    "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are",
    "around", "as", "at", "back", "be", "became", "because", "become",
    "becomes", "becoming", "been", "before", "beforehand", "behind", "being",
    "below", "beside", "besides", "between", "beyond", "bill", "both",
    "bottom", "but", "by", "call", "can", "cannot", "cant", "co", "con",
    "could", "couldnt", "cry", "de", "describe", "detail", "do", "done",
    "down", "due", "during", "each", "eg", "eight", "either", "eleven", "else",
    "elsewhere", "empty", "enough", "etc", "even", "ever", "every", "everyone",
    "everything", "everywhere", "except", "few", "fifteen", "fifty", "fill",
    "find", "fire", "first", "five", "for", "former", "formerly", "forty",
    "found", "four", "from", "front", "full", "further", "get", "give", "go",
    "had", "has", "hasnt", "have", "he", "hence", "her", "here", "hereafter",
    "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his",
    "how", "however", "hundred", "ie", "if", "in", "inc", "indeed",
    "interest", "into", "is", "it", "its", "itself", "keep", "last", "latter",
    "latterly", "least", "less", "ltd", "made", "many", "may", "me",
    "meanwhile", "might", "mill", "mine", "more", "moreover", "most", "mostly",
    "move", "much", "must", "my", "myself", "name", "namely", "neither",
    "never", "nevertheless", "next", "nine", "no", "nobody", "none", "noone",
    "nor", "not", "nothing", "now", "nowhere", "of", "off", "often", "on",
    "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our",
    "ours", "ourselves", "out", "over", "own", "part", "per", "perhaps",
    "please", "put", "rather", "re", "same", "see", "seem", "seemed",
    "seeming", "seems", "serious", "several", "she", "should", "show", "side",
    "since", "sincere", "six", "sixty", "so", "some", "somehow", "someone",
    "something", "sometime", "sometimes", "somewhere", "still", "such",
    "system", "take", "ten", "than", "that", "the", "their", "them",
    "themselves", "then", "thence", "there", "thereafter", "thereby",
    "therefore", "therein", "thereupon", "these", "they", "thick", "thin",
    "third", "this", "those", "though", "three", "through", "throughout",
    "thru", "thus", "to", "together", "too", "top", "toward", "towards",
    "twelve", "twenty", "two", "un", "under", "until", "up", "upon", "us",
    "very", "via", "was", "we", "well", "were", "what", "whatever", "when",
    "whence", "whenever", "where", "whereafter", "whereas", "whereby",
    "wherein", "whereupon", "wherever", "whether", "which", "while", "whither",
    "who", "whoever", "whole", "whom", "whose", "why", "will", "with",
    "within", "without", "would", "yet", "you", "your", "yours", "yourself",
    "yourselves", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l",
    "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"])

特征抽取

  • BOW
  • TF-IDF
  • LDA

文本特征提取类

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, HashingVectorizer

import sys
!ls ../package/
sys.path.insert(0, "../package/")
from ltp import LTP
nlp = LTP(path="base")

from gensim.models import Word2Vec

class TextFeatures:
    def __init__(self, ngram_range=(1, 2)):
        self.cvt = CountVectorizer(tokenizer=self.tokenizer, ngram_range=ngram_range)
        self.tvt = TfidfVectorizer(tokenizer=self.tokenizer, ngram_range=ngram_range)
        self.hvt = HashingVectorizer(tokenizer=self.tokenizer, ngram_range=ngram_range)
        self.cleaner = TextCleaner(remove_html_label=True, remove_stop_words=True, with_space=True)

    def clean_text(self, text_list):
        return self.cleaner.clean_text(text_list)

    def tokenizer(self, text):
        return text.split(" ")

    def get_bow(self, text_list):
        return self.cvt.fit_transform(text_list)

    def get_tfidf(self, text_list):
        return self.tvt.fit_transform(text_list)

    def get_hashing(self, text_list):
        return self.hvt.fit_transform(text_list)
ltp


file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
train_df = pd.read_csv("../0.数据/1.情感分析/NLPCC14-SC/train.tsv", sep="\t", error_bad_lines=False)
train_df.head()
labeltext_a
set(train_df["label"]), train_df.shape
({0, 1}, (10000, 2))
cleaner = TextCleaner(remove_html_label=True, remove_stop_words=True, with_space=True)
contents = ['   大家好, 欢迎一起来学习文本的空格   去除   !']
results = cleaner.clean_text(contents)
print(results)
0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]


['好 , 学习 文本 空格 去除 !']
tqdm.pandas(desc="clean data")
train_df["cleaned_text"] = cleaner.clean_text(train_df["text_a"].values)
0%|          | 0/10000 [00:00<?, ?it/s]



  0%|          | 0/40 [00:00<?, ?it/s]
train_df.to_csv("cleaned_train.csv", index=None)
# import torch
# from tqdm.auto import tqdm

# tokenized_text = []
# text_list = list(train_df["cleaned_text"].values)
# with torch.no_grad():
#     steps = 256
#     for start_idx in tqdm(range(0, train_df.shape[0], steps)):
# #         print(start_idx)
#         if start_idx + steps > train_df.shape[0]:
#             tokenized_text += nlp.seg(text_list[start_idx:])[0]
#         else:
#             tokenized_text += nlp.seg(text_list[start_idx:start_idx+steps])[0]
# from joblib import dump, load
# 关掉显存占用
# from numba import cuda

# cuda.select_device(0)
# cuda.close()

BOW

!ls ../1.基础/停用词/
中文停用词库.txt  哈工大停用词表.txt  四川大学停用词表.txt  百度停用词表.txt
from glob import glob
# 停用词列表
stop_words = []
txt_list = glob("../1.基础/停用词/*.txt")
for txt_path in txt_list:
    with open(txt_path, "r") as fp:
        lines = fp.readlines()
    stop_words += [line.strip() for line in lines]
len(stop_words)
3893
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, HashingVectorizer
from sklearn.linear_model import Ridge, Lasso, LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
def tokenizer(text):
    return text.split(" ")
# corpus = [" ".join(text_list) for text_list in tokenized_text]
# corpus[:2]
corpus = train_df["cleaned_text"].values
cvt = CountVectorizer(stop_words=stop_words, tokenizer=tokenizer, ngram_range=(1, 2))
x_cvt = cvt.fit_transform(corpus)
len(cvt.vocabulary_)
137525
y = train_df["label"].values
X_train, X_val, y_train, y_val = train_test_split(x_cvt, y, test_size=0.1)

clf = Ridge(alpha=500.)
clf.fit(X_train, y_train)

print("train score: ")
y_pred = clf.predict(X_train)
print(roc_auc_score(y_train, y_pred), accuracy_score(y_train, y_pred>0.5))
print()
print("valid score: ")
y_pred = clf.predict(X_val)
print(roc_auc_score(y_val, y_pred), accuracy_score(y_val, y_pred>0.5))
train score: 
0.8657380740314067 0.798

valid score: 
0.8009079767378523 0.733

TFIDF

from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, HashingVectorizer
tvt = TfidfVectorizer(stop_words=stop_words, tokenizer=tokenizer, ngram_range=(1, 2))
x_tvt = tvt.fit_transform(corpus)
len(tvt.vocabulary_)
137525
y = train_df["label"].values
X_train, X_val, y_train, y_val = train_test_split(x_tvt, y, test_size=0.1)

clf = Ridge(alpha=10.)
clf.fit(X_train, y_train)

print("train score: ")
y_pred = clf.predict(X_train)
print(roc_auc_score(y_train, y_pred), accuracy_score(y_train, y_pred>0.5))
print()
print("valid score: ")
y_pred = clf.predict(X_val)
print(roc_auc_score(y_val, y_pred), accuracy_score(y_val, y_pred>0.5))
train score: 
0.9349220324539836 0.8745555555555555

valid score: 
0.7963706773775423 0.728

HashingVectorizer

from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, HashingVectorizer
hvt = HashingVectorizer(stop_words=stop_words, tokenizer=tokenizer, ngram_range=(1, 2))
x_hvt = hvt.fit_transform(corpus)
y = train_df["label"].values
X_train, X_val, y_train, y_val = train_test_split(x_hvt, y, test_size=0.1)

clf = Ridge(alpha=1.)
clf.fit(X_train, y_train)

print("train score: ")
y_pred = clf.predict(X_train)
print(roc_auc_score(y_train, y_pred), accuracy_score(y_train, y_pred>0.5))
print()
print("valid score: ")
y_pred = clf.predict(X_val)
print(roc_auc_score(y_val, y_pred), accuracy_score(y_val, y_pred>0.5))
train score: 
0.99204728016389 0.969

valid score: 
0.8349841394447204 0.749

LDA

train_df = pd.read_csv("./cleaned_train.csv")
train_df.head()
labeltext_acleaned_text
from glob import glob
# 停用词列表
stop_words = []
txt_list = glob("../1.基础/停用词/*.txt")
for txt_path in txt_list:
    with open(txt_path, "r") as fp:
        lines = fp.readlines()
    stop_words += [line.strip() for line in lines]
len(stop_words)
3893
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, HashingVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.linear_model import Ridge, Lasso, LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
def tokenizer(text):
    return text.split(" ")

corpus = train_df["cleaned_text"].values
corpus = [string if string is not np.nan else "" for string in corpus]
cvt = CountVectorizer(tokenizer=tokenizer, ngram_range=(1, 2))
x_cvt = cvt.fit_transform(corpus)
lda = LatentDirichletAllocation(n_components=32, doc_topic_prior=None, topic_word_prior=None, learning_method='batch', 
                                learning_decay=0.7, learning_offset=50.0, max_iter=10, batch_size=128, evaluate_every=-1, 
                                total_samples=1000000.0, perp_tol=0.1, mean_change_tol=0.001, max_doc_update_iter=100, 
                                n_jobs=None, verbose=0, random_state=402)
docres = lda.fit_transform(x_cvt)
docres.shape
(10000, 32)
y = train_df["label"].values
X_train, X_val, y_train, y_val = train_test_split(docres, y, test_size=0.1)

clf = Ridge(alpha=500.)
clf.fit(X_train, y_train)

print("train score: ")
y_pred = clf.predict(X_train)
print(roc_auc_score(y_train, y_pred), accuracy_score(y_train, y_pred>0.5))
print()
print("valid score: ")
y_pred = clf.predict(X_val)
print(roc_auc_score(y_val, y_pred), accuracy_score(y_val, y_pred>0.5))
train score: 
0.5984059229289742 0.5741111111111111

valid score: 
0.5797141495568878 0.57

gensim

corpus = [string.split(" ") for string in corpus]
from gensim import corpora
dictionary = corpora.Dictionary(corpus)
dictionary.save('qzone.dict')
dictionary.filter_extremes(no_below=20, no_above=0.5)
dictionary.compactify()
corpus = [dictionary.doc2bow(s) for s in corpus]
corpora.MmCorpus.serialize('corpus_bow.mm', corpus)  # 存储语料库
from gensim.models import LdaModel

num_topics = 100
chunksize = 2000
passes = 20
iterations = 400
eval_every = None 

temp = dictionary[0]
id2word = dictionary.id2token

model = LdaModel(
    corpus=corpus,
    id2word=id2word,
    chunksize=chunksize,
    alpha='auto',
    eta='auto',
    iterations=iterations,
    num_topics=num_topics,
    passes=passes,
    eval_every=eval_every
)

model.save('qzone.model')
top_topics = model.top_topics(corpus)
avg_topic_coherence = sum([t[1] for t in top_topics]) / num_topics
print('Average topic coherence: %.4f.' % avg_topic_coherence)
Average topic coherence: -5.7200.
len(top_topics), len(corpus)
(100, 10000)

LTP特征提取

import sys
!ls ../package/

sys.path.insert(0, "../package/")

from ltp import LTP
nlp = LTP(path="base")
ltp


file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
seg, hidden = nlp.seg(["他叫汤姆去拿外衣。"])
pos = nlp.pos(hidden)
ner = nlp.ner(hidden)
srl = nlp.srl(hidden)
dep = nlp.dep(hidden)
sdp = nlp.sdp(hidden)

对于LTP提取的特征,可以参考LTP的文档

  • 静态词向量
  • 动态词向量

推荐系统的基本概念

王树森大佬又开了一门公开课:推荐系统,抱着学习的心态来学习下王老师的课。并做个笔记。

视频地址

github课件:https://github.com/wangshusen/Recomme…

基本概念:

曝光:类似系统给你的推荐的内容

点击:用户点击推荐的内容

阅读:用户点击后在页面停留一段时间

转化流程:

用户行为:点击、点赞、收藏、转发

消费指标:用于反应消费侧对推荐系统的满意程度(非最重要)

消费指标:点击率 (click rate)、交互率 (engagement rate)

北极星指标(最核心指标):用户规模、消费、发布 (关键指标)

DAU:日活跃用户数,用户本日登入小红书,就算一个DAU(且不重复计数)

MAU: 用户本月登入小红书,就算一个MAU(且不重复计数)

实验流程:离线实验、AB测试、推全

离线实验只能反映部分指标,还需要线上实验。

推荐系统链路

链路包括召回、粗排、精排、重排。

– 召回(retrieval):快速从海量数据中取回几千个用户可能感兴趣的物品。

– 粗排:用小规模的模型的神经网络给召回的物品打分,然后做截断,选出分数最高的几百个物品。

– 精排:用大规模神经网络给粗排选中的几百个物品打分,可以做截断,也可以不做截断。 – 重排:对精排结果做多样性抽样,得到几十个物品,然后用规则调整物品的排序。

当用户刷新页面时候,系统就会调用几十条召回通道,每个通道取回几百篇笔记内容,然后使用 用小规模的模型的神经网络给召回的物品打分,然后做截断,选出分数最高的几百个物品。 在下一部精排: 用大规模神经网络给粗排选中的几百个物品打分,可以做截断,也可以不做截断。最后:对精排结果做多样性抽样,得到几十个物品,然后用规则调整物品的排序。

重排

做多样性抽样(⽐如MMR、DPP),从⼏百篇中选出⼏⼗篇。
• ⽤规则打散相似笔记。
• 插⼊广告、运营推广内容,根据⽣态要求调整排序。

总结:

推荐系统的小流量A/B测试 (线上实验)

推荐系统算法工程师的日常工作就是改进模型和策略,目标是提升推荐系统的业务指标。所有对模型和策略的改进,都需要经过线上 AB 测试,用实验数据来验证模型和策略是否有效。

小流量:比如只对10%的用户开放该算法,观测用户的反馈,这样避免大范围的影响。

使用随机分桶测试不同的实验参数效果:

分层实验:解决流量不足的问题(测试的用户不足)

同层互斥,不同层正交:

实验推全和反转实验

python 包、模块的书写 以及 __all__ 变量的用法

一、模块

相信使用过Python编写代码的同学,会经常在文件头看到这样的import …,是的,这就是导入模块的语句,而每一个后缀名为.py的文件都是一个模块。

import jieba
import os 

1. 什么是模块?

  逻辑上来说模块是一组功能的组合;实质上一个模块就是一个包含了python定义和声明的文件,文件名就是模块名字加上.py的后缀。

import加载的模块分为四个通用类别:

a. 使用python编写的代码(.py文件);
b. 已被编译为共享库或DLL的C或C++扩展;
c. 包好一组模块的包
d. 使用C编写并链接到python解释器的内置模块;

如何使用模块?
  想要使用模块,必须先要将模块加载进来,可以通过关键字 import 或 from进行加载;需要注意的是模块和当前文件在不同的命名空间中。

2. 模块的构成

  模块可以包含可执行的语句和函数的定义,这些语句的目的是初始化模块,它们只在模块名第一次遇到导入import语句时才执行(import语句是可以在程序中的任意位置使用的,且针对同一个模块很import多次,为了防止你重复导入,python的优化手段是:第一次导入后就将模块名加载到内存了,后续的import语句仅是对已经加载大内存中的模块对象增加了一次引用,不会重新执行模块内的语句

二、模块的导入

1、导入整个模块

  比如我们有一个myModule的文件夹,里面有一个first.py文件,文件中的内容如下

a = 1
def myfun(s):
    print(s + 1)

  在myModule的文件夹下打开终端/cmd,输入python进入命令行交互模式
写完模块导入的语句之后,接着就可以调用该模块下的函数了。调用方式为

>>> import first
>>> a
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NameError: name 'a' is not defined
>>> first.a
1
>>> first.myfun(2)
3

在这里插入图片描述
2、导入特定的函数/变量

  所以说first.py文件就是一个模块,可以用import导入,里面变量和方法都要用first.前缀来引用,如果想不使用这个前缀或是我们只是想要使用模块中的某个函数,就可以只导入该变量或函数。导入方式为:from module_name import function_name。
  如果导入的是变量,就可以直接输入变量名来获得变量的值;如果直接导入的是函数,可以直接使用function_name() 的方式调用函数,无需在函数名前面加上模块名。

# 导入变量
>>> from first import a
>>> a
1
# 导入函数
>>> from first import myfun
>>> myfun(3)
4
# 一次导入多个变量
>>> from first import a,myfun
>>> a
1
>>> myfun(5)
6
# 导入模块中全部变量
>>> from first import *
>>> a
1
>>> myfun(5)
6
>>>

3、使用as给模块指定别名

  可以在后面使用as给函数指定别名。句式如:import module_name as new_name,

>>> import first as f
>>> f.a
1
>>> f.myfun(6)
7

在上述导入函数的基础上,可以在后面用as语句给导入的函数指定别名。句式如:from module_name import function_name as new_function。

>>> from first import myfun as add
>>> add(8)
9

三、包、库

模块(module) 其实就是py文件,里面定义了一些函数、类、变量等。
包(package) 是多个模块的聚合体形成的文件夹,里面可以是多个py文件,也可以嵌套文件夹。
是参考其他编程语言的说法,是指完成一定功能的代码集合,在python中的形式就是模块和包。

一个包的架构:

sound/                          Top-level package
      __init__.py               Initialize the sound package
      formats/                  Subpackage for file format conversions
              __init__.py
              wavread.py
              wavwrite.py
              aiffread.py
              aiffwrite.py
              auread.py
              auwrite.py
              ...
      effects/                  Subpackage for sound effects
              __init__.py
              echo.py
              surround.py
              reverse.py
              ...
      filters/                  Subpackage for filters
              __init__.py
              equalizer.py
              vocoder.py
              karaoke.py
              ...

Python 只把含 __init__.py 文件的目录当成包。这样可以防止以 string 等通用名称命名的目录,无意中屏蔽出现在后方模块搜索路径中的有效模块。 最简情况下,__init__.py 只是一个空文件,但该文件也可以执行包的初始化代码,或设置 __all__ 变量,详见下文。

四、包的导入

导入包的本质:导入一个包就是执行包下的__init__.py文件

只要一个文件夹下面有个__init__.py 文件,那么这个文件夹就可以看做是一个包

包导入的过程和模块的基本一致,只是导入包的时候会执行此包目录下的 init.py 而不是模块里面的语句了。另外,如果只是单纯的导入包,而包的 init.py 中又没有明确的其他初始化操作,那么此包下面的模块是不会自动导入的。

另外需要注意两点

  1. __ init__ .py文件编写时,如果要在__init__.py中导入其他模块中的变量,即使__ init__.py文件和abcd.py文件在同一个文件夹下,也不能from abcd import b,要从abcd文件从哪里来开始写,即从包的名称开始,from folder.abcd import b。
  2. folder文件夹里的嵌套文件夹内不需要新建__init__.py文件即可像模块一样调用,但是一般还是要新建这个文件,可以方便地导入常用变量。
  3. init.py文件其实是一个特殊的文件,它相当于名为folder模块,即如果使用import folder则可以调用在__init__.py文件文件中定义的变量。

五、__ all __

使用 from sound.effects import * 时会发生什么?理想情况下,该语句在文件系统查找并导入包的所有子模块。这项操作花费的时间较长,并且导入子模块可能会产生不必要的副作用,这种副作用只有在显式导入子模块时才会发生。

唯一的解决方案是提供包的显式索引。import 语句使用如下惯例:如果包的 __init__.py 代码定义了列表 __all__,运行 from package import * 时,它就是用于导入的模块名列表。发布包的新版本时,包的作者应更新此列表。如果包的作者认为没有必要在包中执行导入 * 操作,也可以不提供此列表。例如,sound/effects/__init__.py 文件包含以下代码:

__all__ = ["echo", "surround", "reverse"]

这将意味着将 from sound.effects import * 导入 sound.effects 包的三个命名的子模块。

如果没有定义 __all__from sound.effects import * 语句 不会 把包 sound.effects 中所有子模块都导入到当前命名空间;该语句只确保导入包 sound.effects (可能还会运行 __init__.py 中的初始化代码),然后,再导入包中定义的名称。这些名称包括 __init__.py 中定义的任何名称(以及显式加载的子模块),还包括之前 import 语句显式加载的包里的子模块。

变量__all__的好处:只会导出all中的子模块,可以有效地避免命名空间的污染,并加速模块的导入

一、模块公开接口的一种约定
__all__可以在模块级别暴露接口,形式如下:
__all__ = [“foo”, “bar”]
Python 没有原生的可见性控制,其可见性的维护是靠一套需要大家自觉遵守的”约定“,比如,下划线开头的变量对外部不可见。
__all__ 是针对模块公开接口的一种约定,以提供了”白名单“的形式暴露接口。如果定义了__all__,其他文件中使用from xxx import *导入该文件时,只会导入 __all__ 列出的成员,可以其他成员都被排除在外。
如,test1.py,test2.py,test3.py三个文件:
test1.py
#__all__ = [‘func’]
def func():
pass

test2.py
import test1

__all__ = [‘func2’, ‘test1’]
def func2():
pass

def func22():
pass

test3.py
from test2 import *

func2() #能正常引用
test1.func() #能正常引用
func22() #不能正常引用

二、控制 from xxx import * 的行为
python不提倡用 from xxx import * 这种写法。如果一个模块 xxx 没有定义 __all__,执行 from spam import * 时会将 xxx 中所有非下划线开头的成员(包括该模块import的其他模块成员)都会导入当前命名空间,这样就可能弄脏当前的命名空间。显式声明了 __all__,import * 就只会导入 __all__ 列出的成员,如果 __all__ 定义有误,还会明确地抛出异常,方便检查错误。

三、为 lint 等代码检查工具提供辅助
编写库时,经常会在 __init__.py 中暴露整个包的 API,而这些 API 的实现可能是在包的其他模块中。如果仅仅这样写:from xxx import a, b,一些代码检查工具,如 pyflakes 会报错,认为变量 a和 b import 了但没被使用。一个可行的方法是把这个警告压掉:from xxx import a, b # noqa (No Q/A,即无质量保证),但更好的方法是显式定义 __all__,这样代码检查工具就会理解,从而不再报 unused variables 的警告。

四、定义 all 需要注意的地方

  • __all__ 的形式都是 list类型。如果写成其他类型, pyflakes 等 lint 工具可能无法识别。
  • 不能动态生成 __all__,如使用列表解析式。__all__ 的作用是定义公开接口,需要以字面量的形式显式写出来。
  • 即使定义了 __all__, 也不应该在非临时代码中使用 from xxx import * 语法,或用编程工具模拟 Ruby 的自动 import。Python 不像 Ruby,没有 Module 这类成员,模块就是命名空间隔离的执行者。如果打破了这一层,引入诸多动态因素,生产环境中跑的代码就可能充满不确定性,调试也会变得困难。
  • 按照 PEP8 建议的风格,__all__ 应该写在所有 import 语句下面,函数、常量等成员定义的上面。
  • 如果一个模块需要暴露的接口改动频繁,__all__ 可以这样定义:

__all__ = [
“foo”,
“bar”,
“egg”,
]
这样修改一个暴露的接口只修改一行,方便版本控制的时候看 diff。最后多出的逗号在 Python 中是允许的,符合 PEP8 风格。

由上面的输出结果,我们可以知道import *只会导入__all__中指定的变量,无论是否以下划线开头。这样限制可以防止import *命令导入太多变量污染命名空间,过滤掉一些中间变量如b

五、模块导入的绝对引用与相对引用

python中的import分为绝对引用和相对引用两种。它们之间的差异在于,引用模块时,定位被引用模块位置 的方式不同。

绝对引用是通过.的连接,指定出最高级文件(夹),到目标文件的绝对路径。我们上面的所有用法都属于绝对引用。

而相对引用是指定待引用模块与当前文件的相对位置,.表示上一级文件

  • 绝对引用:from folder.abcd import myclass
  • 相对引用:from .abcd import myclass

在实际使用中,无论是绝对导入还是相对导入都要注意,如何导入与被调用位置有关。

Pytorch 中 model.eval() model.train() 和 with torch.no_grad() 的区别

1、model.eval() model.train()区别

model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层。

官方文档 model.train()
启用 Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和 Dropout,需要在训练时添加model.train()model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

官方文档 model.eval()
不启用 Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。

在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。

2 . model.eval()和with torch.no_grad()的区别:


在PyTorch中进行validation时,会使用model.eval()切换到测试模式,在该模式下,

主要用于通知dropout层和batchnorm层在train和val模式间切换
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); batchnorm层会继续计算数据的mean和var等参数并更新。
在val模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反传(backprobagation)


with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用,具体行为就是停止gradient计算,从而节省了GPU算力和显存,但是并不会影响dropout和batchnorm层的行为。


使用场景:
如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation的结果;而with torch.zero_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储gradient),从而可以更快计算,也可以跑更大的batch来测试。

Python装饰器:python中的@符号的作用 以及 torch中经常出现的 @torch.no_grad()

@符号是装饰器(修饰符)的语法糖,在定义函数的时候使用,避免再一次赋值操作

装饰器(Decorators)是 Python 的一个重要部分。简单地说:他们是修改其他函数的功能的函数。他们有助于让我们的代码更简短,也更Pythonic(Python范儿)。大多数初学者不知道在哪儿使用它们,所以我将要分享下,哪些区域里装饰器可以让你的代码更简洁。 首先,让我们讨论下如何写你自己的装饰器。

‘@’符号用作函数修饰符是python2.4新增加的功能,修饰符必须出现在函数定义前一行,不允许和函数定义在同一行。也就是说@A def f(): 是非法的。只可以在模块或类定义层内对函数进行修饰,不允许修饰一个类。一个修饰符就是一个函数,它将被修饰的函数做为参数,并返回修饰后的同名函数或其它可调用的东西。

实例(1):

def spamrun(fn):
   def sayspam(*args):
       print("spam,spam,spam")
   return sayspam
@spamrun
def useful(a,b):
   print (a**2+b**2)

执行: useful(3,4)

返回:spam,spam,spam

def addspam(fn):
   def new(*args):
       print "spam,spam,spam"
       return fn(*args)
   return new

@addspam
def useful(a,b):
   print a**2+b**2

执行: useful(4,3)

结果:

spam,spam,spam

25

@torch.no_grad()

@torch.no_grad()
def eval():
	...

@torch.no_grad()后面的函数的数据不需要计算梯度,也不会进行反向传播

Python装饰器:

装饰器本质上是一个Python函数,它可以让其他函数在不需要做任何代码变动的前提下增加额外功能,装饰器的返回值也是一个函数对象。它经常用于有切面需求的场景,比如:插入日志、性能测试、事务处理、缓存、权限校验等场景。装饰器是解决这类问题的绝佳设计,有了装饰器,我们就可以抽离出大量与函数功能本身无关的雷同代码并继续重用。概括的讲,装饰器的作用就是为已经存在的对象添加额外的功能。

先来看一个简单例子:

def foo():
    print('i am foo')

现在有一个新的需求,希望可以记录下函数的执行日志,于是在代码中添加日志代码:

def foo():
    print('i am foo')
    logging.info("foo is running")

bar()、bar2()也有类似的需求,怎么做?再写一个logging在bar函数里?这样就造成大量雷同的代码,为了减少重复写代码,我们可以这样做,重新定义一个函数:专门处理日志 ,日志处理完之后再执行真正的业务代码

def use_logging(func):
    logging.warn("%s is running" % func.__name__)
    func()

def bar():
    print('i am bar')

use_logging(bar)

逻辑上不难理解, 但是这样的话,我们每次都要将一个函数作为参数传递给use_logging函数。而且这种方式已经破坏了原有的代码逻辑结构,之前执行业务逻辑时,执行运行bar(),但是现在不得不改成use_logging(bar)。那么有没有更好的方式的呢?当然有,答案就是装饰器。

简单装饰器

def use_logging(func):

    def wrapper(*args, **kwargs):
        logging.warn("%s is running" % func.__name__)
        return func(*args, **kwargs)
    return wrapper

def bar():
    print('i am bar')

bar = use_logging(bar)
bar()

函数use_logging就是装饰器,它把执行真正业务方法的func包裹在函数里面,看起来像bar被use_logging装饰了。在这个例子中,函数进入和退出时 ,被称为一个横切面(Aspect),这种编程方式被称为面向切面的编程(Aspect-Oriented Programming)。

@符号是装饰器的语法糖,在定义函数的时候使用,避免再一次赋值操作

方法一:不用语法糖@符号​​​​​​​

# 装饰器不传入参数时
f = decorator(函数名)

# 装饰器传入参数时
f = (decorator(参数))(函数名)


方法二:采用语法糖@符号​​​​​​​

# 已定义的装饰器
@decorator 
def f():  
    pass

# 执行被装饰过的函数 
f()
def use_logging(func):

    def wrapper(*args, **kwargs):
        logging.warn("%s is running" % func.__name__)
        return func(*args)
    return wrapper

@use_logging
def foo():
    print("i am foo")

@use_logging
def bar():
    print("i am bar")

bar()

如上所示,这样我们就可以省去bar = use_logging(bar)这一句了,直接调用bar()即可得到想要的结果。如果我们有其他的类似函数,我们可以继续调用装饰器来修饰函数,而不用重复修改函数或者增加新的封装。这样,我们就提高了程序的可重复利用性,并增加了程序的可读性。

装饰器在Python使用如此方便都要归因于Python的函数能像普通的对象一样能作为参数传递给其他函数,可以被赋值给其他变量,可以作为返回值,可以被定义在另外一个函数内。

带参数的装饰器

装饰器还有更大的灵活性,例如带参数的装饰器:在上面的装饰器调用中,比如@use_logging,该装饰器唯一的参数就是执行业务的函数。装饰器的语法允许我们在调用时,提供其它参数,比如@decorator(a)。这样,就为装饰器的编写和使用提供了更大的灵活性。

def use_logging(level):
    def decorator(func):
        def wrapper(*args, **kwargs):
            if level == "warn":
                logging.warn("%s is running" % func.__name__)
            return func(*args)
        return wrapper

    return decorator

@use_logging(level="warn")
def foo(name='foo'):
    print("i am %s" % name)

foo()

上面的use_logging是允许带参数的装饰器。它实际上是对原有装饰器的一个函数封装,并返回一个装饰器。我们可以将它理解为一个含有参数的闭包。当我 们使用@use_logging(level=”warn”)调用的时候,Python能够发现这一层的封装,并把参数传递到装饰器的环境中。

类装饰器

再来看看类装饰器,相比函数装饰器,类装饰器具有灵活度大、高内聚、封装性等优点。使用类装饰器还可以依靠类内部的__call__方法,当使用 @ 形式将装饰器附加到函数上时,就会调用此方法。

__call__方法 : 在生成一个类的实例时,自动自行一次call方法

当执行Foo时候生成一个实例,就会自动调用__call__方法

class Foo(object):
    def __init__(self, func):
    self._func = func

def __call__(self):
    print ('class decorator runing')
    self._func()
    print ('class decorator ending')

@Foo
def bar():
    print ('bar')

bar()

functools.wraps

使用装饰器极大地复用了代码,但是他有一个缺点就是原函数的元信息不见了,比如函数的docstring、__name__、参数列表,先看例子:

装饰器

def logged(func):
    def with_logging(*args, **kwargs):
        print func.__name__ + " was called"
        return func(*args, **kwargs)
    return with_logging

函数

@logged
def f(x):
   """does some math"""
   return x + x * x

该函数完成等价于:

def f(x):
    """does some math"""
    return x + x * x
f = logged(f)

不难发现,函数f被with_logging取代了,当然它的docstring,__name__就是变成了with_logging函数的信息了。

print f.__name__    # prints 'with_logging'
print f.__doc__     # prints None

这个问题就比较严重的,好在我们有functools.wraps,wraps本身也是一个装饰器,它能把原函数的元信息拷贝到装饰器函数中,这使得装饰器函数也有和原函数一样的元信息了。

from functools import wraps
def logged(func):
    @wraps(func)
    def with_logging(*args, **kwargs):
        print func.__name__ + " was called"
        return func(*args, **kwargs)
    return with_logging

@logged
def f(x):
    """does some math"""
    return x + x * x

print f.__name__  # prints 'f'
print f.__doc__   # prints 'does some math'

内置装饰器

@staticmathod、@classmethod、@property

@property

把类内方法当成属性来使用,必须要有返回值,相当于getter;

假如没有定义 @func.setter 修饰方法的话,就是只读属性

class Car:

    def __init__(self, name, price):
        self._name = name
        self._price = price    
     
    @property
    def car_name(self):
        return self._name
        
     # car_name可以读写的属性   
     @car_name.setter
     def car_name(self, value):
         self._name = value
         
     # car_price是只读属性 
     @property
     def car_price(self):
         return str(self._price) + '万'
         
benz = Car('benz', 30)

print(benz.car_name)   # benz
benz.car_name = "baojun"
print(benz.car_name)   # baojun
print(benz.car_price)  # 30万

@staticmethod

静态方法,不需要表示自身对象的self和自身类的cls参数,就跟使用函数一样。

静态方法的使用场景:

如果在方法中不需要访问任何实例方法和属性,纯粹地通过传入参数并返回数据的功能性方法,那么它就适合用静态方法来定义,它节省了实例化对象的开销成本,往往这种方法放在类外面的模块层作为一个函数存在也是没问题的,而放在类中,仅为这个类服务。

@classmethod

类方法,不需要self参数,但第一个参数需要是表示自身类的cls参数。

类方法的使用场景有:

作为工厂方法创建实例对象,例如内置模块 datetime.date 类中就有大量使用类方法作为工厂方法,以此来创建date对象。如果希望在方法里面调用静态类,那么把方法定义成类方法是合适的,因为要是定义成静态方法,那么你就要显示地引用类A,这对继承来说可不是一件好事情。

例子

class Demo(object):

    text = "三种方法的比较"
    
    def instance_method(self):
        print("调用实例方法")

    @classmethod
    def class_method(cls):
        print("调用类方法")
        print("在类方法中 访问类属性 text: {}".format(cls.text))
        print("在类方法中 调用实例方法 instance_method: {}".format(cls().instance_method()))

    @staticmethod
    def static_method():
        print("调用静态方法")
        print("在静态方法中 访问类属性 text: {}".format(Demo.text))
        print("在静态方法中 调用实例方法 instance_method: {}".format(Demo().instance_method()))

if __name__ == "__main__":
    # 实例化对象
    d = Demo()
    
    # 对象可以访问 实例方法、类方法、静态方法
    # 通过对象访问text属性
    print(d.text)
    
    # 通过对象调用实例方法
    d.instance_method()
    
    # 通过对象调用类方法
    d.class_method()
    
    # 通过对象调用静态方法
    d.static_method()
    
    # 类可以访问类方法、静态方法
    # 通过类访问text属性
    print(Demo.text)
    
    # 通过类调用类方法
    Demo.class_method()
    
    # 通过类调用静态方法
    Demo.static_method()

@staticmethod 和 @classmethod 的 区别 和 使用场景

在上述例子中,我们可以看出,

区别

在定义静态类方法和类方法时,@staticmethod 装饰的静态方法里面,想要访问类属性或调用实例方法,必须需要把类名写上;

@classmethod装饰的类方法里面,会传一个cls参数,代表本类,这样就能够避免手写类名的硬编码。

在调用静态方法和类方法时,实际上写法都差不多,一般都是通过 类名.静态方法() 或 类名.类方法()。也可以用实例对象调用类方法和静态方法。 对象可以访问 实例方法、类方法、静态方法 , 类可以访问类方法、静态方法

也可以用实例化对象去调用静态方法和类方法但为了和实例方法区分,最好还是用类去调用静态方法和类方法。

使用场景

所以,在定义类的时候,

假如不需要用到与类相关的属性或方法时,就用静态方法@staticmethod

假如需要用到与类相关的属性或方法,然后又想表明这个方法是整个类通用的,而不是对象特异的,就可以使用类方法@classmethod

装饰器的顺序

@a
@b
@c
def f ():

等效于

f = a(b(c(f)))