PyTorch 中支持更多 ONNX 算子

学习了 PyTorch 转 ONNX 的方法,可以发现 PyTorch 对 ONNX 的支持还不错。但在实际的部署过程中,难免碰到模型无法用原生 PyTorch 算子表示的情况。这个时候,我们就得考虑扩充 PyTorch,即在 PyTorch 中支持更多 ONNX 算子。

而要使 PyTorch 算子顺利转换到 ONNX ,我们需要保证以下三个环节都不出错:

  • 算子在 PyTorch 中有实现
  • 有把该 PyTorch 算子映射成一个或多个 ONNX 算子的方法
  • ONNX 有相应的算子

可在实际部署中,这三部分的内容都可能有所缺失。其中最坏的情况是:我们定义了一个全新的算子,它不仅缺少 PyTorch 实现,还缺少 PyTorch 到 ONNX 的映射关系。但所谓车到山前必有路,对于这三个环节,我们也分别都有以下的添加支持的方法:

  • PyTorch 算子
    • 组合现有算子
    • 添加 TorchScript 算子
    • 添加普通 C++ 拓展算子
  • 映射方法
    • 为 ATen 算子添加符号函数
    • 为 TorchScript 算子添加符号函数
    • 封装成 torch.autograd.Function 并添加符号函数
  • ONNX 算子
    • 使用现有 ONNX 算子
    • 定义新 ONNX 算子

那么面对不同的情况时,就需要我们灵活地选用和组合这些方法。听起来是不是很复杂?别担心,本篇文章中,我们将围绕着三种算子映射方法,学习三个添加算子支持的实例,来理清如何合适地为 PyTorch 算子转 ONNX 算子的三个环节添加支持。

 支持 ATen 算子

实际的部署过程中,我们都有可能会碰到一个最简单的算子缺失问题: 算子在 ATen 中已经实现了,ONNX 中也有相关算子的定义,但是相关算子映射成 ONNX 的规则没有写。在这种情况下,我们只需要为 ATen 算子补充描述映射规则的符号函数就行了。

ATen 是 PyTorch 内置的 C++ 张量计算库,PyTorch 算子在底层绝大多数计算都是用 ATen 实现的。

上期习题中,我们曾经提到了 ONNX 的 Asinh 算子。这个算子在 ATen 中有实现,却缺少了映射到 ONNX 算子的符号函数。在这里,我们来尝试为它补充符号函数,并导出一个包含这个算子的 ONNX 模型。

获取 ATen 中算子接口定义

为了编写符号函数,我们需要获得 asinh 推理接口的输入参数定义。这时,我们要去 torch/_C/_VariableFunctions.pyi 和 torch/nn/functional.pyi 这两个文件中搜索我们刚刚得到的这个算子名。这两个文件是编译 PyTorch 时本地自动生成的文件,里面包含了 ATen 算子的 PyTorch 调用接口。通过搜索,我们可以知道 asinh 在文件 torch/_C/_VariableFunctions.pyi 中,其接口定义为:

def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ... 

经过这些步骤,我们确认了缺失的算子名为 asinh,它是一个有实现的 ATen 算子。我们还记下了 asinh 的调用接口。接下来,我们要为它补充符号函数,使它在转换成 ONNX 模型时不再报错。

添加符号函数

到目前为止,我们已经多次接触了定义 PyTorch 到 ONNX 映射规则的符号函数了。现在,我们向大家正式介绍一下符号函数。

符号函数,可以看成是 PyTorch 算子类的一个静态方法。在把 PyTorch 模型转换成 ONNX 模型时,各个 PyTorch 算子的符号函数会被依次调用,以完成 PyTorch 算子到 ONNX 算子的转换。符号函数的定义一般如下:

def symbolic(g: torch._C.Graph, input_0: torch._C.Value, input_1: torch._C.Value, ...): 

其中,torch._C.Graph 和 torch._C.Value 都对应 PyTorch 的 C++ 实现里的一些类。我们在这篇文章不深究它们的细节(感兴趣的话可以参考我们的 TorchScript 系列文章中对 trace 机制的解读),只需要知道第一个参数就固定叫 g,它表示和计算图相关的内容;后面的每个参数都表示算子的输入,需要和算子的前向推理接口的输入相同。对于 ATen 算子来说,它们的前向推理接口就是上述两个 .pyi 文件里的函数接口。

g 有一个方法 op。在把 PyTorch 算子转换成 ONNX 算子时,需要在符号函数中调用此方法来为最终的计算图添加一个 ONNX 算子。其定义如下:

def op(name: str, input_0: torch._C.Value, input_1: torch._C.Value, ...) 

其中,第一个参数是算子名称。如果该算子是普通的 ONNX 算子,只需要把它在 ONNX 官方文档里的名称填进去即可(我们稍后再讲其他情况)。

在最简单的情况下,我们只要把 PyTorch 算子的输入用g.op()一一对应到 ONNX 算子上即可,并把g.op()的返回值作为符号函数的返回值。在情况更复杂时,我们转换一个 PyTorch 算子可能要新建若干个 ONNX 算子。

补充完了背景知识,让我们回到 asinh 算子上,来为它编写符号函数。我们先去翻阅一下 ONNX 算子文档,学习一下我们在符号函数里的映射关系 g.op() 里应该怎么写。Asinh 的文档写道:该算子有一个输入 input,一个输出 output,二者的类型都为张量。

到这里,我们已经完成了信息收集环节。我们在上一小节得知了 asinh 的推理接口定义,在这一小节里收集了 ONNX 算子 Asinh 的定义。现在,我们可以用代码来补充这二者的映射关系了。在刚刚导出 asinh 算子的代码中,我们添加以下内容:

from torch.onnx.symbolic_registry import register_op 
 
def asinh_symbolic(g, input, *, out=None): 
    return g.op("Asinh", input) 
 
register_op('asinh', asinh_symbolic, '', 9)  

这里的asinh_symbolic就是asinh的符号函数。从除g以外的第二个输入参数开始,其输入参数应该严格对应它在 ATen 中的定义:

def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ... 

在符号函数的函数体中,g.op("Asinh", input)则完成了 ONNX 算子的定义。其中,第一个参数"Asinh"是算子在 ONNX 中的名称。至于第二个参数 input,如我们刚刚在文档里所见,这个算子只有一个输入,因此我们只要把符号函数的输入参数 input 对应过去就行。ONNX 的 Asinh 的输出和 ATen 的 asinh 的输出是一致的,因此我们直接把 g.op() 的结果返回即可。

定义完符号函数后,我们要把这个符号函数和原来的 ATen 算子“绑定”起来。这里,我们要用到 register_op 这个 PyTorch API 来完成绑定。如示例所示,只需要一行简单的代码即可把符号函数 asinh_symbolic 绑定到算子 asinh 上:

register_op('asinh', asinh_symbolic, '', 9) 

register_op的第一个参数是目标 ATen 算子名,第二个是要注册的符号函数,这两个参数很好理解。第三个参数是算子的“域”,对于普通 ONNX 算子,直接填空字符串即可。第四个参数表示向哪个算子集版本注册。我们遵照 ONNX 标准,向第 9 号算子集注册。值得注意的是,这里向第 9 号算子集注册,不代表较新的算子集(第 10 号、第 11 号……)都得到了注册。在示例中,我们先只向第 9 号算子集注册。

整理一下,我们最终的代码如下:

import torch 
 
class Model(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
 
    def forward(self, x): 
        return torch.asinh(x) 
 
from torch.onnx.symbolic_registry import register_op 
 
def asinh_symbolic(g, input, *, out=None): 
    return g.op("Asinh", input) 
 
register_op('asinh', asinh_symbolic, '', 9) 
 
model = Model() 
input = torch.rand(1, 3, 10, 10) 
torch.onnx.export(model, input, 'asinh.onnx') 
 

成功导出的话,asinh.onnx 应该长这个样子:

测试算子

在完成了一份自定义算子后,我们一定要测试一下算子的正确性。一般我们要用 PyTorch 运行一遍原算子,再用推理引擎(比如 ONNX Runtime)运行一下 ONNX 算子,最后比对两次的运行结果。对于我们刚刚得到的 asinh.onnx,可以用如下代码来验证:

import onnxruntime 
import torch 
import numpy as np 
 
class Model(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
 
    def forward(self, x): 
        return torch.asinh(x) 
 
model = Model() 
input = torch.rand(1, 3, 10, 10) 
torch_output = model(input).detach().numpy() 
 
sess = onnxruntime.InferenceSession('asinh.onnx') 
ort_output = sess.run(None, {'0': input.numpy()})[0] 
 
assert np.allclose(torch_output, ort_output) 

在这份代码里,我们用 PyTorch 做了一遍推理,并把结果转成了 numpy 格式。之后,我们又用 ONNX Runtime 对 onnx 文件做了一次推理。

忘了 ONNX Runtime 的调用方法的话,欢迎回顾第一篇教程~

最后,我们使用 np.allclose 来保证两个结果张量的误差在一个可以允许的范围内。一切正常的话,运行这段代码后,assert 所在行不会报错,程序应该没有任何输出。

支持 TorchScript 算子

对于一些比较复杂的运算,仅使用 PyTorch 原生算子是无法实现的。这个时候,就要考虑自定义一个 PyTorch 算子,再把它转换到 ONNX 中了。新增 PyTorch 算子的方法有很多,PyTorch 官方比较推荐的一种做法是添加 TorchScript 算子 。

由于添加算子的方法较繁琐,我们今天跳过新增 TorchScript 算子的内容,以可变形卷积(Deformable Convolution)算子为例,介绍为现有 TorchScript 算子添加 ONNX 支持的方法。

可变形卷积(Deformable Convolution)是在 Torchvision 中实现的 TorchScript 算子,虽然尚未得到广泛支持,但是出现在许多模型中。

有了支持 ATen 算子的经验之后,我们可以知道为算子添加符号函数一般要经过以下几步:

  1. 获取原算子的前向推理接口。
  2. 获取目标 ONNX 算子的定义。
  3. 编写符号函数并绑定。

在为可变形卷积添加符号函数时,我们也可以尝试走一遍这个流程。

使用 TorchScript 算子

和之前一样,我们首先定义一个包含了算子的模型,为之后转换 ONNX 模型做准备。

import torch 
import torchvision 
 
class Model(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
        self.conv1 = torch.nn.Conv2d(3, 18, 3) 
        self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3) 
 
    def forward(self, x): 
        return self.conv2(x, self.conv1(x)) 

其中,torchvision.ops.DeformConv2d 就是 Torchvision 中的可变形卷积层。相比于普通卷积,可变形卷积的其他参数都大致相同,唯一的区别就是在推理时需要多输入一个表示偏移量的张量。

然后,我们查询算子的前向推理接口。DeformConv2d 层最终会调用 deform_conv2d 这个算子。我们可以在 torchvision/csrc/ops/deform_conv2d.cpp 中查到该算子的调用接口:

m.def(TORCH_SELECTIVE_SCHEMA( 
      "torchvision::deform_conv2d(Tensor input,  
      Tensor weight,  
      Tensor offset,  
      ...... 
      bool use_mask) -> Tensor")); 

那么接下来,根据之前的经验,我们就是要去 ONNX 官方文档中查找算子的定义了。

自定义 ONNX 算子

很遗憾的是,如果我们去 ONNX 的官方算子页面搜索 “deform”,将搜不出任何内容。目前,ONNX 还没有提供可变形卷积的算子,我们要自己定义一个 ONNX 算子了。

我们在前面讲过,g.op() 是用来定义 ONNX 算子的函数。对于 ONNX 官方定义的算子,g.op() 的第一个参数就是该算子的名称。而对于一个自定义算子,g.op() 的第一个参数是一个带命名空间的算子名,比如:

g.op("custom::deform_conv2d, ...) 

其中,”::”前面的内容就是我们的命名空间。该概念和 C++ 的命名空间类似,是为了防止命名冲突而设定的。如果在 g.op() 里不加前面的命名空间,则算子会被默认成 ONNX 的官方算子。

PyTorch 在运行 g.op() 时会对官方的算子做检查,如果算子名有误,或者算子的输入类型不正确, g.op() 就会报错。为了让我们随心所欲地定义新 ONNX 算子,我们必须设定一个命名空间,给算子取个名,再定义自己的算子。

我们在第一篇教程讲过:ONNX 是一套标准,本身不包括实现。在这里,我们就简略地定义一个 ONNX 可变形卷积算子,而不去写它在某个推理引擎上的实现。在后续的文章中,我们再介绍在各个推理引擎中添加新 ONNX 算子支持的方法。此处,我们只关心如何导出一个包含新 ONNX 算子节点的 onnx 文件。因此,我们可以为新算子编写如下简单的符号函数:

@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "none") 
def symbolic(g,  
        input, 
        weight, 
        offset, 
        mask, 
        bias, 
        stride_h, stride_w, 
        pad_h, pad_w, 
        dil_h, dil_w, 
        n_weight_grps, 
        n_offset_grps, 
        use_mask): 
    return g.op("custom::deform_conv2d", input, offset) 
 

在这个符号函数中,我们以刚刚搜索到的算子输入参数作为符号函数的输入参数,并只用 input 和 offset 来构造一个简单的 ONNX 算子。

这段代码中,最令人疑惑的就是装饰器 @parse_args 了。简单来说,TorchScript 算子的符号函数要求标注出每一个输入参数的类型。比如”v”表示 Torch 库里的 value 类型,一般用于标注张量,而”i”表示 int 类型,”f”表示 float 类型,”none”表示该参数为空。具体的类型含义可以在 torch.onnx.symbolic_helper.py (https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_helper.py)中查看。这里输入参数中的 input, weight, offset, mask, bias 都是张量,所以用”v”表示。后面的其他参数同理。我们不必纠结于 @parse_args 的原理,根据实际情况对符号函数的参数标注类型即可。

有了符号函数后,我们通过如下的方式注册符号函数:

register_custom_op_symbolic("torchvision::deform_conv2d", symbolic, 9) 

和前面的 register_op 类似,注册符号函数时,我们要输入算子名、符号函数、算子集版本。与前面不同的是,这里的算子集版本是最早生效版本,在这里设定版本 9,意味着之后的第 10 号、第 11 号……版本集都能使用这个新算子。

最后,我们完整的模型导出代码如下:

import torch 
import torchvision 
 
class Model(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
        self.conv1 = torch.nn.Conv2d(3, 18, 3) 
        self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3) 
 
    def forward(self, x): 
        return self.conv2(x, self.conv1(x)) 
 
from torch.onnx import register_custom_op_symbolic 
from torch.onnx.symbolic_helper import parse_args 
 
@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "none") 
def symbolic(g,  
        input, 
        weight, 
        offset, 
        mask, 
        bias, 
        stride_h, stride_w, 
        pad_h, pad_w, 
        dil_h, dil_w, 
        n_weight_grps, 
        n_offset_grps, 
        use_mask): 
    return g.op("custom::deform_conv2d", input, offset) 
 
register_custom_op_symbolic("torchvision::deform_conv2d", symbolic, 9) 
 
model = Model() 
input = torch.rand(1, 3, 10, 10) 
torch.onnx.export(model, input, 'dcn.onnx') 
 

代码成功运行的话,我们应该能得到如下的 ONNX 模型:

可以看到,我们自定义的 ONNX 算子 deform_conv2d 包含了两个输入,一个输出,和我们预想得一样。

使用 torch.autograd.Function

最后,我们来学习一种简单的为 PyTorch 添加 C++ 算子实现的方法,来代替较为复杂的新增 TorchScript 算子。同时,我们会用 torch.autograd.Function 封装这个新算子。torch.autograd.Function 能完成算子实现和算子调用的隔离。不管算子是怎么实现的,它封装后的使用体验以及 ONNX 导出方法会和原生的 PyTorch 算子一样。这是我们比较推荐的为算子添加 ONNX 支持的方法。

为了应对更复杂的情况,我们来自定义一个奇怪的 my_add 算子。这个算子的输入张量 a, b ,输出 2a + b 的值。我们会先把它在 PyTorch 中实现,再把它导出到 ONNX 中。

为 PyTorch 添加 C++ 拓展

为 PyTorch 添加简单的 C++ 拓展还是很方便的。对于我们定义的 my_add 算子,可以用以下的 C++ 源文件来实现。我们把该文件命名为 “my_add.cpp”:

// my_add.cpp 
 
#include <torch/torch.h> 
 
torch::Tensor my_add(torch::Tensor a, torch::Tensor b) 
{ 
    return 2 * a + b; 
} 
 
PYBIND11_MODULE(my_lib, m) 
{ 
    m.def("my_add", my_add); 
} 

由于在 PyTorch 中添加 C++ 拓展和模型部署关系不大,这里我们仅给出这个简单的示例,并不对其原理做过多讲解。

在这段代码中,torch::Tensor 就是 C++ 中 torch 的张量类型,它的加法和乘法等运算符均已重载。因此,我们可以像对普通标量一样对张量做加法和乘法。

轻松地完成了算子的实现后,我们用 PYBIND11_MODULE 来为 C++ 函数提供 Python 调用接口。这里的 my_lib 是我们未来要在 Python 里导入的模块名。双引号中的 my_add 是 Python 调用接口的名称,这里我们对齐 C++ 函数的名称,依然用 “my_add”这个名字。

之后,我们可以编写如下的 Python 代码并命名为 “setup.py”,来编译刚刚的 C++ 文件:

from setuptools import setup 
from torch.utils import cpp_extension 
 
setup(name='my_add', 
      ext_modules=[cpp_extension.CppExtension('my_lib', ['my_add.cpp'])], 
      cmdclass={'build_ext': cpp_extension.BuildExtension}) 

这段代码使用了 Python 的 setuptools 编译功能和 PyTorch 的 C++ 拓展工具函数,可以编译包含了 torch 库的 C++ 源文件。这里我们需要填写的只有模块名和模块中的源文件名。我们刚刚把模块命名为 my_lib,而源文件只有一个 my_add.cpp,因此拓展模块那一行要写成 ext_modules=[cpp_extension.CppExtension('my_lib', ['my_add.cpp'])],

之后,像处理普通的 Python 包一样执行安装命令,我们的 C++ 代码就会自动编译了。

python setup.py develop 

用 torch.autograd.Function 封装

直接用 Python 接口调用 C++ 函数不太“美观”,一种比较优雅的做法是把这个调用接口封装起来。这里我们用 torch.autograd.Function 来封装算子的底层调用:

import torch 
import my_lib 
class MyAddFunction(torch.autograd.Function): 
 
    @staticmethod 
    def forward(ctx, a, b): 
        return my_lib.my_add(a, b) 
 
    @staticmethod 
    def symbolic(g, a, b): 
        two = g.op("Constant", value_t=torch.tensor([2])) 
        a = g.op('Mul', a, two) 
        return g.op('Add', a, b) 

我们在前面的教程中已经见过 torch.autograd.Function,这里我们正式地对其做一个介绍。Function 类本身表示 PyTorch 的一个可导函数,只要为其定义了前向推理和反向传播的实现,我们就可以把它当成一个普通 PyTorch 函数来使用。

PyTorch 会自动调度该函数,合适地执行前向和反向计算。对模型部署来说,Function 类有一个很好的性质:如果它定义了 symbolic 静态方法,该 Function 在执行 torch.onnx.export() 时就可以根据 symbolic 中定义的规则转换成 ONNX 算子。这个 symbolic 就是前面提到的符号函数,只是它的名称必须是 symbolic 而已。

在 forward 函数中,我们用 my_lib.my_add(a, b) 就可以调用之前写的C++函数了。这里 my_lib 是库名,my_add 是函数名,这两个名字是在前面C++的 PYBIND11_MODULE 中定义的。

在 symbolic 函数中,我们用 g.op() 定义了三个算子:常量、乘法、加法。这里乘法和加法的用法和前面提到的 asinh 一样,只需要根据 ONNX 算子定义规则把输入参数填入即可。而在定义常量算子时,我们要把 PyTorch 张量的值传入 value_t 参数中。

在 ONNX 中,我们需要把新建常量当成一个算子来看待,尽管这个算子并不会以节点的形式出现在 ONNX 模型的可视化结果里。

把算子封装成 Function 后,我们可以把 my_add算子用起来了。

my_add = MyAddFunction.apply 
 
class MyAdd(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
 
    def forward(self, a, b): 
        return my_add(a, b) 

在这份代码里,我们先用 my_add = MyAddFunction.apply 获取了一个奇怪的变量。这个变量是用来做什么的呢?其实,applytorch.autograd.Function 的一个方法,这个方法完成了 Function 在前向推理或者反向传播时的调度。我们在使用 Function 的派生类做推理时,不应该显式地调用 forward(),而应该调用其 apply 方法。

这里我们使用 my_add = MyAddFunction.apply 把这个调用方法取了一个更简短的别名 my_add。以后在使用 my_add 算子时,我们应该忽略 MyAddFunction 的实现细节,而只通过 my_add 这个接口来访问算子。这里 my_add 的地位,和 PyTorch 的 asinhinterpolateconv2d等原生函数是类似的。

有了访问新算子的接口后,我们可以进一步把算子封装成一个神经网络中的计算层。我们定义一个叫做的 MyAdd 的 torch.nn.Module,它封装了my_add,就和封装了conv2d 的 torch.nn.Conv2d 一样。

测试算子

费了好大的功夫来“包装”我们的新算子后,我们终于可以来使用它了。和之前的测试流程一样,让我们用下面的代码来导出一个包含新算子的 ONNX 模型,并验证一下它是否正确。

model = MyAdd() 
input = torch.rand(1, 3, 10, 10) 
torch.onnx.export(model, (input, input), 'my_add.onnx') 
torch_output = model(input, input).detach().numpy() 
 
import onnxruntime 
import numpy as np 
sess = onnxruntime.InferenceSession('my_add.onnx') 
ort_output = sess.run(None, {'a': input.numpy(), 'b': input.numpy()})[0] 
 
assert np.allclose(torch_output, ort_output) 

在这份代码中,我们直接把 MyAdd 作为要导出的模型。我们计算了一个 PyTorch 模型的运行结果,又导出 ONNX 模型,计算了 ONNX 模型在 ONNX Runtime 上的运算结果。如果一切正常的话,这两个结果是一样的,这份代码不会报任何错误,没有任何输出。

可视化一下 my_add.onnx,可以看出,和我们设计得一样,my_add 算子被翻译成了两个 ONNX 算子节点(其中常量算子被放入了 Mul 的参数中)。

整理一下,整个流程的 Python 代码如下:

import torch 
import my_lib 
class MyAddFunction(torch.autograd.Function): 
 
    @staticmethod 
    def forward(ctx, a, b): 
        return my_lib.my_add(a, b) 
 
    @staticmethod 
    def symbolic(g, a, b): 
        two = g.op("Constant", value_t=torch.tensor([2])) 
        a = g.op('Mul', a, two) 
        return g.op('Add', a, b) 
 
my_add = MyAddFunction.apply 
 
class MyAdd(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
 
    def forward(self, a, b): 
        return my_add(a, b) 
 
model = MyAdd() 
input = torch.rand(1, 3, 10, 10) 
torch.onnx.export(model, (input, input), 'my_add.onnx') 
torch_output = model(input, input).detach().numpy() 
 
import onnxruntime 
import numpy as np 
sess = onnxruntime.InferenceSession('my_add.onnx') 
ort_output = sess.run(None, {'a': input.numpy(), 'b': input.numpy()})[0] 
 
assert np.allclose(torch_output, ort_output) 

总结

在这篇教程中,我们围绕“为 ATen 算子添加符号函数”、“为 TorchScript 算子添加符号函数”、“封装成 torch.autograd.Function 并添加符号函数”这三种添加映射关系的方法,讲解了 3 个为 PyTorch 和 ONNX 添加支持的实例。在这个过程中,我们学到了很多零散的知识,来总结一下吧。

  • ATen 是 PyTorch 的 C++ 张量运算库。通过查询 torch/_C/_VariableFunctions.pyi 和 torch/nn/functional.pyi,我们可以知道 ATen 算子的 Python 接口定义。
  • 用 register_op 可以为 ATen 算子补充注册符号函数
  • 用 register_custom_op_symbolic 可以为 TorchScript 算子补充注册符号函数
  • 如何在 PyTorch 里添加 C++ 拓展
  • 如何用 torch.autograd.Function 封装一个自定义 PyTorch 算子
  • 如何编写符号函数 symbolic(g, ...)
  • 如何用 g.op() 把一个 PyTorch 算子映射成一个或多个 ONNX 算子,或者是自定义的 ONNX 算子。

python 深浅拷贝

拷贝是Python学习过程中很容易被忽略,但是在项目开发过程中起着重要作用的一个概念。

有很多开发者由于忽视这一点,甚至导致项目中出现很严重的BUG。

我之前就因为这样的一个小问题,一不小心掉坑里了。反复定位才发现竟然是由这个容易被忽视的问题引起的….

在这篇文章中,我们将看看如何在Python中深度和浅度拷贝对象,深入探讨Python 如何处理对象引用和内存中的对象。

浅拷贝

当我们在 Python 中使用赋值语句 (=) 来创建复合对象的副本时,例如,列表或类实例或基本上任何包含其他对象的对象,Python 并没有克隆对象本身。

相反,它只是将引用绑定到目标对象上。

想象一下,我们有一个列表,里面有以下元素。

original_list =[[1,2,3], [4,5,6], ["X", "Y", "Z"]]

如果我们尝试使用如下的赋值语句来复制我们的原始列表。

shallow_copy_list = original_list
print(shallow_copy_list)

它可能看起来像我们克隆了我们的对象,或许很多同学会认为生成了两个对象,

[[1,2,3], [4,5,6], ['X', 'Y', 'Z']]

但是,我们真的有两个对象吗?

不,并没有。我们有两个引用变量,指向内存中的同一个对象。通过打印这两个对象在内存中的ID,可以很容易地验证这一点。

id(original_list) # 4517445712
id(shallow_copy_list) # 4517445712

一个更具体的证明可以通过尝试改变 “两个列表”中的一个值来观察–而实际上,我们改变的是同一个列表,两个指针指向内存中的同一个对象。

让我们来改变original_list所指向的对象的最后一个元素。

# Last element of last element
original_list[-1][-1] = "ZZZ"
print(original_list)

输出结果是:

[[1, 2, 3], [4, 5, 6], ['X', 'Y', 'ZZZ']]

两个引用变量都指向同一个对象,打印shallow_copy_list将返回相同的结果。

print(shallow_copy_list) # [[1, 2, 3], [4, 5, 6], ['X', 'Y', 'ZZZ']]

浅层复制是指复制一个对象的引用并将其存储在一个新的变量中的过程。original_list和shallow_copy_list只是指向内存(RAM)中相同地址的引用,这些引用存储了[[1, 2, 3], [4, 5, 6], ['X', 'Y', 'ZZZ']的值。

我们在复制过程中,并没有生成一个新的对象,试想一下,如果不理解这一点,很多同学会误认为它生成了一个完全独立的新对象,殊不知,在对这个新变量shallow_copy_list进行操作时,原来的变量original_list也会跟随改变。

除了赋值语句之外,还可以通过Python标准库的拷贝模块实现浅拷贝

要使用拷贝模块,我们必须首先导入它。

import copy
second_shallow_copy_list = copy.copy(original_list)

把它们都打印出来,看看它们是否引用了相同的值。

print(original_list)
print(second_shallow_copy_list)

不出所料,确实如此,

[[1, 2, 3], [4, 5, 6], ['X', 'Y', 'ZZZ']]
[[1, 2, 3], [4, 5, 6], ['X', 'Y', 'ZZZ']]

通常,你想复制一个复合对象,例如在一个方法的开始,然后修改克隆的对象,但保持原始对象的原样,以便以后再使用它。

为了达到这个目的,我们需要对该对象进行深度复制。现在让我们来学习一下什么是深度拷贝以及如何深度拷贝一个复合对象。

深拷贝

深度复制一个对象意味着真正地将该对象和它的值克隆到内存中的一个新的副本(实例)中,并具有这些相同的值。

通过深度拷贝,我们实际上可以创建一个独立于原始数据的新对象,但包含相同的值,而不是为相同的值创建新的引用。

在一个典型的深度拷贝过程中,首先,一个新的对象引用被创建,然后所有的子对象被递归地加入到父对象中。

这样一来,与浅层拷贝不同,对原始对象的任何修改都不会反映在拷贝对象中(反之亦然)。

下面是一个典型的深度拷贝的简单图示。

要在 Python 中深度拷贝一个对象,我们使用 copy 模块的 deepcopy()方法。

让我们导入 copy 模块并创建一个列表的深度拷贝。

import copy
 
original_list = [[1,2,3], [4,5,6], ["X", "Y", "Z"]]
deepcopy_list = copy.deepcopy(original_list)

现在让我们打印我们的列表,以确保输出是相同的,以及他们的ID是唯一的。

print(id(original_list), original_list)
print(id(deepcopy_list), deepcopy_list)

输出结果证实,我们已经为自己创建了一个真正的副本。

4517599280, [[1, 2, 3], [4, 5, 6], ['X', 'Y', 'Z']]
4517599424, [[1, 2, 3], [4, 5, 6], ['X', 'Y', 'Z']]

现在让我们试着修改我们的原始列表,把最后一个列表的最后一个元素改为 “O”,然后打印出来看看结果。

original_list[-1][-1] = "O"
print(original_list)

我们得到了预期的结果。

[[1, 2, 3], [4, 5, 6], ['X', 'Y', 'O']]

现在,如果我们继续前进并尝试打印我们的副本列表,之前的修改并没有影响新的变量。

print(deepcopy_list) # [[1, 2, 3], [4, 5, 6], ['X', 'Y', 'Z']]

记住,copy()deepcopy()方法适用于其他复合对象。这意味着,你也可以用它们来创建类实例的副本。

python 类型注释 # type

Type Comments[类型注解]

注释是在Python 3中引入的,并且它们没有被反向移植到Python 2.这意味着如果您正在编写需要支持旧版Python的代码,则无法使用注释。

要向函数添加类型注释,您可以执行以下操作:

import math 
def circumference(radius):    
# type: (float) -> float    
   return 2 * math.pi * radius

类型注释只是注释,所以它们可以用在任何版本的Python中。

类型注释由类型检查器直接处理,所以不存在__annotations__字典对象中:

>>> circumference.__annotations__{}

类型注释必须以type: 字面量开头,并与函数定义位于同一行或下一行。如果您想用几个参数来注释一个函数,您可以用逗号分隔每个类型:

def headline(text, width=80, fill_char="-"):  
  # type: (str, int, str) -> str    
   return f" {text.title()} ".center(width, fill_char) 

print(headline("type comments work", width=40))

您还可以使用自己的注释在单独的行上编写每个参数:

# headlines.py
 
  def headline(
      text,           # type: str
      width=80,       # type: int
      fill_char="-",  # type: str
  ):                  # type: (...) -> str
      return f" {text.title()} ".center(width, fill_char)
 
 print(headline("type comments work", width=40))

通过Python和Mypy运行示例:

$  python headlines.py
---------- Type Comments Work ---------- 
$ mypy headline.py
$

如果传入一个字符串width=”full”,再次运行mypy会出现一下错误。

$ mypy headline.py
headline.py:10: error: Argument "width" to "headline" has incompatible
                       type "str"; expected "int"

您还可以向变量添加类型注释。这与您向参数添加类型注释的方式类似:

pi = 3.142  # type: float

上面的例子可以检测出pi是float类型。

Python爬虫:常用的爬虫工具汇总

最近需要跑一个风格迁移cyclegan项目,这个并不难,github上随便search一个就可以,但是数据集很是头疼,没有比较合适的数据集,因此需要自己在网上寻找一些图片,但如果不使用爬虫爬数据,不知道要到猴年马月,因此需要使用爬虫爬取谷歌、百度、以及一些图片网站的图片,之前倒是学过request库,但没怎么用过,因此先开个帖子,记录下相关知识。

爬虫整体思路:页面下载 –> 页面解析 –> 数据存储

一、页面下载器

 requests(必学)
      
     Requests: HTTP for Humans™
  1. python爬虫入门requests模块
  2. Python爬虫:requests库基本使用
  3. Python爬虫:使用requests库下载大文件
  4. Python爬虫:requests多进程爬取猫眼电影榜单
  5. requests InsecureRequestWarning: Unverified HTTPS request is being made.
  1. scrapy
    1. Python网络爬虫之scrapy框架
    2. scrapy学习
    3. Python爬虫:关于scrapy模块的请求头
    4. Python爬虫:scrapy框架请求参数meta、headers、cookies一探究竟
    5. Python爬虫:scrapy辅助功能实用函数
  2. selenium+chrome + PhantomJS(抓取动态网页,不推荐)
    1. mac下安装selenium+phantomjs+chromedriver
    2. Python爬虫:selenium模块基本使用
    3. Python爬虫selenium模块
    4. Python爬虫:selenium和Chrome无头浏览器抓取烯牛数据动态网页
    5. Python爬虫:利用selenium爬取淘宝商品信息
    6. Python爬虫:selenium使用chrome和PhantomJS实用参数
  1. Splash(抓取动态网页,推荐)
    1. Python爬虫:splash的安装与简单示例
    2. Python爬虫:splash+requests简单示例
    3. Python爬虫:scrapy利用splash爬取动态网页

总结: 对于下载器而言,python自带的urllib就不要花时间去学了,学了就忘,直接requests能满足大部分测试+抓取需求,进阶工程化scrapy,动态网页优先找API接口,如果有简单加密就破解,实在困难就使用splash渲染

二、页面解析器

  1. BeautifulSoup(入门级)
    1. Python爬虫入门BeautifulSoup模块
    2. Beautiful Soup 4.4.0 文档¶
  1. pyquery (类似jQuery)
    1. Python爬虫:pyquery模块解析网页
  2. lxml
    1. Python爬虫:使用lxml解析网页内容
  1. parsel
    1. Extract text using CSS or XPath selectors
  2. scrapy的Selector (强烈推荐, 比较高级的封装,基于parsel)
    1. 选择器(Selectors)
    2. python爬虫:scrapy框架xpath和css选择器语法

总结: 其实解析器学习一个就够了,其他都不用学,很多培训会教你从上到下的学习,我不是很推荐,直接学习scrapy的Selector 就行,简单、直接、高效

三、数据存储

  1. txt文本
    1. Python全栈之路:文件file常用操作
  1. csv文件
    1. python读取写入csv文件
  2. sqlite3 (python自带)
    1. Python编程:使用数据库sqlite3
  1. MySQL
    1. SQL:pymysql模块读写mysql数据
  2. MongoDB
    1. Python编程:mongodb的基本增删改查操作

总结: 数据存储没有什么可深究的,按照业务需求来就行,一般快速测试使用MongoDB,业务使用MySQL

四、其他工具

  1. execjs :执行js Python爬虫:execjs在python中运行javascript代码
  2. pyv8: 执行js mac安装pyv8模块-JavaScript翻译成python
  3. html5lib 1. Python爬虫:scrapy利用html5lib解析不规范的html文本

python 爬取网站图片

对于做人工智能来说,最主要的爬取目标是图片,需要在网上获取大量的图片数据用于模型训练。这里参考网上资料,自己写一个简单的爬虫程序。

1、爬取百度图片:

百度图片比较简单,通过一个ajax请求,来获取图片的url:

参数:

2、爬取 谷歌图片:

谷歌跟百度不同,需要使用 selenium

由于google图片界面是属于那种往下划会在本页面中加载出更多信息,但未刷新的机制,但是它又并未使用ajax。
所以这里我们使用selenium。selenium是一个能够模拟浏览器的工具,如果你没有安装,请pip install 一下。
然后是下载符合你的浏览器的驱动,我这里用的是Chrome,所以下载了ChromeDriver,将其放在D:\python\Scripts(你的python安装目录)。
用这两个来模拟用户的浏览器操作。


from selenium import webdriver
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.chrome.options import Options
import time
import os
import urllib.request
import uuid

def download_pic(url, name, path):

    if not os.path.exists(path):
        os.makedirs(path)
    res = urllib.request.urlopen(url, timeout=3).read()
    with open(path + name +'.jpg', 'wb') as file:
        file.write(res)
        file.close()

def get_image_url(num, key_word):

    box = driver.find_element_by_xpath('/html/body/div[1]/div[3]/form/div[1]/div[1]/div[1]/div/div[2]/input')
    box.send_keys(key_word)
    box.send_keys(Keys.ENTER)
    box = driver.find_element_by_xpath('//*[@id="hdtb-msb"]/div[1]/div/div[2]/a').click()

    # 滚动页面
    last_height = driver.execute_script('return document.body.scrollHeight')
    while True:
        driver.execute_script('window.scrollTo(0,document.body.scrollHeight)')
        time.sleep(2)
        new_height = driver.execute_script('return document.body.scrollHeight')
        try:
            driver.find_elements_by_xpath('//*[@id="islmp"]/div/div/div/div/div[5]/input').click()
        except:
            pass
        if new_height == last_height:
            # 点击显示更多结果
            try:
                box = driver.find_element_by_xpath('//*[@id="islmp"]/div/div/div/div[1]/div[2]/div[2]/input').click()
            except:
                break
        last_height = new_height

    image_urls = []

    for i in range(1, num):
        try:
            image = driver.find_element_by_xpath('//*[@id="islrg"]/div[1]/div[' + str(i) + ']/a[1]/div[1]/img')
            # 此选项为下载缩略图
            # image_src = image.get_attribute("src")
            image.click() # 点开大图
            time.sleep(4)  # 因为谷歌页面是动态加载的,需要给予页面加载时间,否则无法获取原图url,如果你的网络状况一般请适当延长
            # 获取原图的url
            image_real = driver.find_element_by_xpath('//*[@id="Sva75c"]/div/div/div[3]/div[2]/c-wiz/div/div[1]/div[1]/div[2]/div[1]/a/img')
            image_url = image_real.get_attribute("src")
            image_urls.append(image_url)
            print(str(i) + ': ' + image_url)
        except:
            print(str(i) + ': error')
            pass
    return image_urls
if __name__ == '__main__':
    # 创建一个参数对象,用来控制chrome是否以无界面模式打开
    ch_op = Options()
    # 设置谷歌浏览器的页面无可视化,如果需要可视化请注释这两行代码
    ch_op.add_argument('--headless')
    ch_op.add_argument('--disable-gpu')

    url = "https://www.google.com/"
    driver = webdriver.Chrome(r'D:\anconda3\chromedriver.exe', options=ch_op)
    driver.get(url)

    key_word = input('请输入关键词:')
    num = int(input('请输入需要下载的图片数:'))
    _path = input('请输入图片保存路径,例如G:\\\\google\\\\images\\\\ :')

    # path = "G:\\google\\images_download\\" + key_word + "\\"  # 图片保存路径改为自己的路径
    path = _path + key_word + "\\"
    print('正在获取图片url...')
    image_urls = get_image_url(num, key_word)
    for index, url in enumerate(image_urls):
        try:
            print('第' + str(index) + '张图片开始下载...')
            download_pic(url, str(uuid.uuid1()), path)
        except Exception as e:
            print(e)
            print('第' + str(index) + '张图片下载失败')
            continue
    driver.quit()

python 异常处理 try except 和 断言(assert)

最近在写代码的时候,很多时候需要考虑各种情况,如果仅仅使用if,会很麻烦,于是想到了python 异常处理和断言,用于判断函数进程。、

异常处理 try

程序在运行的时候,如果python解释器遇到一个错误,会停止程序的执行,
并且提示一些错误的信息,这就是异常
我们在程序开发的时候,很难将所有的特殊情况都处理,
通过异常捕获可以针对
突发事件做集中处理,从而保证程序的健壮性和稳定性

在程序开发中,如果对某些代码的执行不能确定(程序语法完全正确)
可以增加try来捕获异常

try这个关键字来捕获异常
try:尝试执行的代码
except:出现错误的处理 finally:无论是否发生异常,都会执行final部份

try:
    print('try...')
    r = 10 / int('a')
    print('result:', r)
except ValueError as e:
    print('ValueError:', e)
except ZeroDivisionError as e:
    print('ZeroDivisionError:', e)
finally:
    print('finally...')
print('END')

try 语句的工作原理如下:

  • 首先,执行 try 子句 (try 和 except 关键字之间的(多行)语句)。
  • 如果没有触发异常,则跳过 except 子句try 语句执行完毕。
  • 如果在执行 try 子句时发生了异常,则跳过该子句中剩下的部分。 如果异常的类型与 except 关键字后指定的异常相匹配,则会执行 except 子句,然后跳到 try/except 代码块之后继续执行。
  • 如果发生的异常与 except 子句 中指定的异常不匹配,则它会被传递到外部的 try 语句中;如果没有找到处理程序,则它是一个 未处理异常 且执行将终止并输出如上所示的消息。

try 语句可以有多个 except 子句 来为不同的异常指定处理程序。 但最多只有一个处理程序会被执行。 处理程序只处理对应的 try 子句 中发生的异常,而不处理同一 try 语句内其他处理程序中的异常。 except 子句 可以用带圆括号的元组来指定多个异常。

常见异常:

try 语句还有一个可选子句,用于定义在所有情况下都必须要执行的清理操作。

如果存在 finally 子句,则 finally 子句是 try 语句结束前执行的最后一项任务。不论 try 语句是否触发异常,都会执行 finally 子句。以下内容介绍了几种比较复杂的触发异常情景:

  • 如果执行 try 子句期间触发了某个异常,则某个 except 子句应处理该异常。如果该异常没有 except 子句处理,在 finally 子句执行后会被重新触发。
  • except 或 else 子句执行期间也会触发异常。 同样,该异常会在 finally 子句执行之后被重新触发。
  • 如果 finally 子句中包含 breakcontinue 或 return 等语句,异常将不会被重新引发。
  • 如果执行 try 语句时遇到 break,、continue 或 return 语句,则 finally 子句在执行 breakcontinue 或 return 语句之前执行。
  • 如果 finally 子句中包含 return 语句,则返回值来自 finally 子句的某个 return 语句的返回值,而不是来自 try 子句的 return 语句的返回值。

assert(断言)

Python assert(断言)用于判断一个表达式,在表达式条件为 false 的时候触发异常。

断言可以在条件不满足程序运行的情况下直接返回错误,而不必等待程序运行后出现崩溃的情况,例如我们的代码只能在 Linux 系统下运行,可以先判断当前系统是否符合条件。

语法格式如下:

assert expression

等价于:

if not expression:
    raise AssertionError

assert 后面也可以紧跟参数:

assert expression [, arguments]

等价于:

if not expression:
    raise AssertionError(arguments)

中文文本清洗与特征提取

摘自知乎:

bookname嵌入式AI算法研究

中文文本清洗

中文文本清洗:

– 去除指定无用的符号

– 让文本只保留汉字

– 文本中的表情符号去除

– 繁体中文与简体中文转换

中文文本清洗类

import re
from opencc import OpenCC
from bs4 import BeautifulSoup
import jieba
from glob import glob

import torch
from tqdm.auto import tqdm

import sys
!ls ../package/
sys.path.insert(0, "../package/")
from ltp import LTP
nlp = LTP(path="base")

class TextCleaner:
    '''
        批量清洗数据
    '''
    def __init__(self,
                 remove_space=True, # 去除空格
                 remove_suspension=True, # 转换省略号
                 only_zh=False, # 只保留汉子
                 remove_sentiment_character=True, # 去除表情符号
                 to_simple=True, # 转化为简体中文
                 remove_html_label=True,
                 remove_stop_words=False,
                 stop_words_dir="./停用词/",
                 with_space=False,
                 batch_size=256):
        self._remove_space = remove_space
        self._remove_suspension = remove_suspension
        self._remove_sentiment_character = remove_sentiment_character

        self._only_zh = only_zh
        self._to_simple = to_simple

        self._remove_html_label = remove_html_label
        self._remove_stop_words = remove_stop_words
        self._stop_words_dir = stop_words_dir

        self._with_space = with_space
        self._batch_size = batch_size

    def clean_single_text(self, text):
        if self._remove_space:
            text = self.remove_space(text)
        if self._remove_suspension:
            text = self.remove_suspension(text)
        if self._remove_sentiment_character:
            text = self.remove_sentiment_character(text)
        if self._to_simple:
            text = self.to_simple(text)
        if self._only_zh:
            text = self.get_zh_only(text)
        if self._remove_html_label:
            text = self.remove_html(text)
        return text

    def clean_text(self, text_list):
        text_list = [self.clean_single_text(text) for text in tqdm(text_list)]
        tokenized_words_list = self.tokenizer_batch_text(text_list)
        if self._remove_stop_words:
            text_list = [self.remove_stop_words(words_list, self._stop_words_dir, self._with_space) for words_list in tokenized_words_list]
        return text_list

    def remove_space(self, text):     #定义函数
        return text.replace(' ','')   # 去掉文本中的空格

    def remove_suspension(self, text):
        return text.replace('...', '。')

    def get_zh_only(self, text):
        def is_chinese(uchar):
            if uchar >= u'\u4e00' and uchar <= u'\u9fa5':  # 判断一个uchar是否是汉字 中文字符的编码范围 \u4e00 - \u9fff,只要在这个范围就可以
                return True
            else:
                return False

        content = ''
        for i in text:
            if is_chinese(i):
                content = content+i
        return content

    def remove_sentiment_character(self, sentence):    
        pattern = re.compile("[^\u4e00-\u9fa5^,^.^!^,^。^?^?^!^a-z^A-Z^0-9]")  #只保留中英文、数字和符号,去掉其他东西
        #若只保留中英文和数字,则替换为[^\u4e00-\u9fa5^a-z^A-Z^0-9]
        line = re.sub(pattern,'',sentence)  #把文本中匹配到的字符替换成空字符
        new_sentence=''.join(line.split())    #去除空白
        return new_sentence

    def to_simple(self, sentence):
        new_sentence = OpenCC('t2s').convert(sentence)   # 繁体转为简体
        return new_sentence

    def to_tradition(self, sentence):
        new_sentence = OpenCC('s2t').convert(sentence)   # 简体转为繁体
        return new_sentence

    def remove_html(self, text):
        return BeautifulSoup(text, 'html.parser').get_text() #去掉html标签

    def tokenizer_batch_text(self, text_list):
        tokenized_text = []
        len_text = len(text_list)
        with torch.no_grad():
            steps = self._batch_size
            for start_idx in tqdm(range(0, len_text, steps)):
                if start_idx + steps > len_text:
                    tokenized_text += nlp.seg(text_list[start_idx:])[0]
                else:
                    tokenized_text += nlp.seg(text_list[start_idx:start_idx+steps])[0]
        return tokenized_text

    def remove_stop_words(self, words_list, stop_words_dir, with_space=False):
        """
        中文数据清洗  stopwords_chineses.txt存放在博客园文件中
        :param text:
        :return:
        """
        stop_word_filepath_list = glob(stop_words_dir + "/*.txt")
        for stop_word_filepath in stop_word_filepath_list:
            with open(stop_word_filepath) as fp:
                stopwords = {}.fromkeys([line.rstrip() for line in fp]) #加载停用词(中文)
        eng_stopwords = set(stopwords) #去掉重复的词
        words = [w for w in words_list if w not in eng_stopwords] #去除文本中的停用词
        if with_space:
            return ' '.join(words)
        else:
            return ''.join(words)
ltp


file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
cleaner = TextCleaner(remove_stop_words=True, with_space=True)
contents = ['   大家好, 欢迎一起来学习文本的空格   去除   !', '   大家好,文本的空格   去除   !']
results = cleaner.clean_text(contents)
print(results)
0%|          | 0/2 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]


['好 , 学习 文本 空格 去除 !', '好 , 文本 空格 去除 !']

去除空格

# 去除空格
contents = '   大家好, 欢迎一起来学习文本的空格   去除   !'
print('处理前文本:'+contents)
def process(our_data):     #定义函数
    content = our_data.replace(' ','')   # 去掉文本中的空格
    print('处理后文本:'+content)
process(contents)
处理前文本:   大家好, 欢迎一起来学习文本的空格   去除   !
处理后文本:大家好,欢迎一起来学习文本的空格去除!

去除空格的同时把省略号转换为句号

# 去除空格的同时把省略号转换为句号
contents = '   大家好, 这里还有  很多的知识...一起拉学习吧 !'
print('处理前文本:'+contents)
def process(data):     #定义函数
    content1 = data.replace(' ','')    # 去掉文本中的空格
    content2 = content1.replace('...','。')    # 去掉文本中的空格
    print('处理后文本:'+ content2)
process(contents)
处理前文本:   大家好, 这里还有  很多的知识...一起拉学习吧 !
处理后文本:大家好,这里还有很多的知识。一起拉学习吧!

让文本只保留汉字

def is_chinese(uchar):
    if uchar >= u'\u4e00' and uchar <= u'\u9fa5':  # 判断一个uchar是否是汉字
        return True
    else:
        return False

def allcontents(contents):
    content = ''
    for i in contents:
        if is_chinese(i):
            content = content+i
    print('\n处理后的句子为:\n'+content)

centents = '1,2,3...我们开始吧, 加油!'
print('原句子为:\n'+centents)
allcontents(centents)
原句子为:
1,2,3...我们开始吧, 加油!

处理后的句子为:
我们开始吧加油

文本中的表情符号去除

import re
sentence='现在听着音乐,duo rui mi,很开心*_*'
print('原句子为:\n'+sentence)

def clear_character(sentence):    
    pattern = re.compile("[^\u4e00-\u9fa5^,^.^!^a-z^A-Z^0-9]")  #只保留中英文、数字和符号,去掉其他东西
    #若只保留中英文和数字,则替换为[^\u4e00-\u9fa5^a-z^A-Z^0-9]
    line=re.sub(pattern,'',sentence)  #把文本中匹配到的字符替换成空字符
    new_sentence=''.join(line.split())    #去除空白
    print('\n处理后的句子为:\n'+new_sentence) 

clear_character(sentence)
原句子为:
现在听着音乐,duo rui mi,很开心*_*

处理后的句子为:
现在听着音乐,duoruimi,很开心

繁体中文与简体中文转换

from opencc import OpenCC

sentence = '你现在读的这里是简体,这里是繁体,能看懂吗?'
print('原句子为:\n'+sentence)

def Simplified(sentence):
    new_sentence = OpenCC('t2s').convert(sentence)   # 繁体转为简体
    print('\n处理后的句子为:\n'+new_sentence)

def Traditional(sentence):
    new_sentence = OpenCC('s2t').convert(sentence)   # 简体转为繁体
    print('\n处理后的句子为:\n'+new_sentence) 

Simplified(sentence)
Traditional(sentence)
原句子为:
你现在读的这里是简体,这里是繁体,能看懂吗?

处理后的句子为:
你现在读的这里是简体,这里是繁体,能看懂吗?

处理后的句子为:
你现在读的这里是简体,这里是繁体,能看懂吗?

OpenCC的参数设置:

- hk2s: Traditional Chinese (Hong Kong standard) to Simplified Chinese
- s2hk: Simplified Chinese to Traditional Chinese (Hong Kong standard)
- s2t: Simplified Chinese to Traditional Chinese
- s2tw: Simplified Chinese to Traditional Chinese (Taiwan standard)
- s2twp: Simplified Chinese to Traditional Chinese (Taiwan standard, with phrases)
- t2hk: Traditional Chinese to Traditional Chinese (Hong Kong standard)
- t2s: Traditional Chinese to Simplified Chinese
- t2tw: Traditional Chinese to Traditional Chinese (Taiwan standard)
- tw2s: Traditional Chinese (Taiwan standard) to Simplified Chinese
- tw2sp: Traditional Chinese (Taiwan standard) to Simplified Chinese (with phrases)

去除html标签和停用词

from bs4 import BeautifulSoup
import jieba
from glob import glob

def clean_chineses_text(text, with_space=False):
    """
    中文数据清洗  stopwords_chineses.txt存放在博客园文件中
    :param text:
    :return:
    """
    text = BeautifulSoup(text, 'html.parser').get_text() #去掉html标签
    text = jieba.lcut(text)
    stop_word_filepath_list = glob("./停用词/*.txt")
#     print(stop_word_filepath_list)
    for stop_word_filepath in stop_word_filepath_list:
        with open(stop_word_filepath) as fp:
            stopwords = {}.fromkeys([line.rstrip() for line in fp]) #加载停用词(中文)
    eng_stopwords = set(stopwords) #去掉重复的词
    words = [w for w in text if w not in eng_stopwords] #去除文本中的停用词
    if with_space:
        return ' '.join(words)
    else:
        return ''.join(words)
clean_chineses_text("你现在读的这里是简体,这里是繁体,能看懂吗?", with_space=True)
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.703 seconds.
Prefix dict has been built successfully.





'读 简体 , 这里 繁体 , 能看懂 吗 ?'
ENGLISH_STOP_WORDS = frozenset([
    "about", "above", "across", "after", "afterwards", "again", "against",
    "all", "almost", "alone", "along", "already", "also", "although", "always",
    "am", "among", "amongst", "amoungst", "amount", "an", "and", "another",
    "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are",
    "around", "as", "at", "back", "be", "became", "because", "become",
    "becomes", "becoming", "been", "before", "beforehand", "behind", "being",
    "below", "beside", "besides", "between", "beyond", "bill", "both",
    "bottom", "but", "by", "call", "can", "cannot", "cant", "co", "con",
    "could", "couldnt", "cry", "de", "describe", "detail", "do", "done",
    "down", "due", "during", "each", "eg", "eight", "either", "eleven", "else",
    "elsewhere", "empty", "enough", "etc", "even", "ever", "every", "everyone",
    "everything", "everywhere", "except", "few", "fifteen", "fifty", "fill",
    "find", "fire", "first", "five", "for", "former", "formerly", "forty",
    "found", "four", "from", "front", "full", "further", "get", "give", "go",
    "had", "has", "hasnt", "have", "he", "hence", "her", "here", "hereafter",
    "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his",
    "how", "however", "hundred", "ie", "if", "in", "inc", "indeed",
    "interest", "into", "is", "it", "its", "itself", "keep", "last", "latter",
    "latterly", "least", "less", "ltd", "made", "many", "may", "me",
    "meanwhile", "might", "mill", "mine", "more", "moreover", "most", "mostly",
    "move", "much", "must", "my", "myself", "name", "namely", "neither",
    "never", "nevertheless", "next", "nine", "no", "nobody", "none", "noone",
    "nor", "not", "nothing", "now", "nowhere", "of", "off", "often", "on",
    "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our",
    "ours", "ourselves", "out", "over", "own", "part", "per", "perhaps",
    "please", "put", "rather", "re", "same", "see", "seem", "seemed",
    "seeming", "seems", "serious", "several", "she", "should", "show", "side",
    "since", "sincere", "six", "sixty", "so", "some", "somehow", "someone",
    "something", "sometime", "sometimes", "somewhere", "still", "such",
    "system", "take", "ten", "than", "that", "the", "their", "them",
    "themselves", "then", "thence", "there", "thereafter", "thereby",
    "therefore", "therein", "thereupon", "these", "they", "thick", "thin",
    "third", "this", "those", "though", "three", "through", "throughout",
    "thru", "thus", "to", "together", "too", "top", "toward", "towards",
    "twelve", "twenty", "two", "un", "under", "until", "up", "upon", "us",
    "very", "via", "was", "we", "well", "were", "what", "whatever", "when",
    "whence", "whenever", "where", "whereafter", "whereas", "whereby",
    "wherein", "whereupon", "wherever", "whether", "which", "while", "whither",
    "who", "whoever", "whole", "whom", "whose", "why", "will", "with",
    "within", "without", "would", "yet", "you", "your", "yours", "yourself",
    "yourselves", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l",
    "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"])

特征抽取

  • BOW
  • TF-IDF
  • LDA

文本特征提取类

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, HashingVectorizer

import sys
!ls ../package/
sys.path.insert(0, "../package/")
from ltp import LTP
nlp = LTP(path="base")

from gensim.models import Word2Vec

class TextFeatures:
    def __init__(self, ngram_range=(1, 2)):
        self.cvt = CountVectorizer(tokenizer=self.tokenizer, ngram_range=ngram_range)
        self.tvt = TfidfVectorizer(tokenizer=self.tokenizer, ngram_range=ngram_range)
        self.hvt = HashingVectorizer(tokenizer=self.tokenizer, ngram_range=ngram_range)
        self.cleaner = TextCleaner(remove_html_label=True, remove_stop_words=True, with_space=True)

    def clean_text(self, text_list):
        return self.cleaner.clean_text(text_list)

    def tokenizer(self, text):
        return text.split(" ")

    def get_bow(self, text_list):
        return self.cvt.fit_transform(text_list)

    def get_tfidf(self, text_list):
        return self.tvt.fit_transform(text_list)

    def get_hashing(self, text_list):
        return self.hvt.fit_transform(text_list)
ltp


file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
train_df = pd.read_csv("../0.数据/1.情感分析/NLPCC14-SC/train.tsv", sep="\t", error_bad_lines=False)
train_df.head()
labeltext_a
set(train_df["label"]), train_df.shape
({0, 1}, (10000, 2))
cleaner = TextCleaner(remove_html_label=True, remove_stop_words=True, with_space=True)
contents = ['   大家好, 欢迎一起来学习文本的空格   去除   !']
results = cleaner.clean_text(contents)
print(results)
0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]


['好 , 学习 文本 空格 去除 !']
tqdm.pandas(desc="clean data")
train_df["cleaned_text"] = cleaner.clean_text(train_df["text_a"].values)
0%|          | 0/10000 [00:00<?, ?it/s]



  0%|          | 0/40 [00:00<?, ?it/s]
train_df.to_csv("cleaned_train.csv", index=None)
# import torch
# from tqdm.auto import tqdm

# tokenized_text = []
# text_list = list(train_df["cleaned_text"].values)
# with torch.no_grad():
#     steps = 256
#     for start_idx in tqdm(range(0, train_df.shape[0], steps)):
# #         print(start_idx)
#         if start_idx + steps > train_df.shape[0]:
#             tokenized_text += nlp.seg(text_list[start_idx:])[0]
#         else:
#             tokenized_text += nlp.seg(text_list[start_idx:start_idx+steps])[0]
# from joblib import dump, load
# 关掉显存占用
# from numba import cuda

# cuda.select_device(0)
# cuda.close()

BOW

!ls ../1.基础/停用词/
中文停用词库.txt  哈工大停用词表.txt  四川大学停用词表.txt  百度停用词表.txt
from glob import glob
# 停用词列表
stop_words = []
txt_list = glob("../1.基础/停用词/*.txt")
for txt_path in txt_list:
    with open(txt_path, "r") as fp:
        lines = fp.readlines()
    stop_words += [line.strip() for line in lines]
len(stop_words)
3893
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, HashingVectorizer
from sklearn.linear_model import Ridge, Lasso, LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
def tokenizer(text):
    return text.split(" ")
# corpus = [" ".join(text_list) for text_list in tokenized_text]
# corpus[:2]
corpus = train_df["cleaned_text"].values
cvt = CountVectorizer(stop_words=stop_words, tokenizer=tokenizer, ngram_range=(1, 2))
x_cvt = cvt.fit_transform(corpus)
len(cvt.vocabulary_)
137525
y = train_df["label"].values
X_train, X_val, y_train, y_val = train_test_split(x_cvt, y, test_size=0.1)

clf = Ridge(alpha=500.)
clf.fit(X_train, y_train)

print("train score: ")
y_pred = clf.predict(X_train)
print(roc_auc_score(y_train, y_pred), accuracy_score(y_train, y_pred>0.5))
print()
print("valid score: ")
y_pred = clf.predict(X_val)
print(roc_auc_score(y_val, y_pred), accuracy_score(y_val, y_pred>0.5))
train score: 
0.8657380740314067 0.798

valid score: 
0.8009079767378523 0.733

TFIDF

from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, HashingVectorizer
tvt = TfidfVectorizer(stop_words=stop_words, tokenizer=tokenizer, ngram_range=(1, 2))
x_tvt = tvt.fit_transform(corpus)
len(tvt.vocabulary_)
137525
y = train_df["label"].values
X_train, X_val, y_train, y_val = train_test_split(x_tvt, y, test_size=0.1)

clf = Ridge(alpha=10.)
clf.fit(X_train, y_train)

print("train score: ")
y_pred = clf.predict(X_train)
print(roc_auc_score(y_train, y_pred), accuracy_score(y_train, y_pred>0.5))
print()
print("valid score: ")
y_pred = clf.predict(X_val)
print(roc_auc_score(y_val, y_pred), accuracy_score(y_val, y_pred>0.5))
train score: 
0.9349220324539836 0.8745555555555555

valid score: 
0.7963706773775423 0.728

HashingVectorizer

from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, HashingVectorizer
hvt = HashingVectorizer(stop_words=stop_words, tokenizer=tokenizer, ngram_range=(1, 2))
x_hvt = hvt.fit_transform(corpus)
y = train_df["label"].values
X_train, X_val, y_train, y_val = train_test_split(x_hvt, y, test_size=0.1)

clf = Ridge(alpha=1.)
clf.fit(X_train, y_train)

print("train score: ")
y_pred = clf.predict(X_train)
print(roc_auc_score(y_train, y_pred), accuracy_score(y_train, y_pred>0.5))
print()
print("valid score: ")
y_pred = clf.predict(X_val)
print(roc_auc_score(y_val, y_pred), accuracy_score(y_val, y_pred>0.5))
train score: 
0.99204728016389 0.969

valid score: 
0.8349841394447204 0.749

LDA

train_df = pd.read_csv("./cleaned_train.csv")
train_df.head()
labeltext_acleaned_text
from glob import glob
# 停用词列表
stop_words = []
txt_list = glob("../1.基础/停用词/*.txt")
for txt_path in txt_list:
    with open(txt_path, "r") as fp:
        lines = fp.readlines()
    stop_words += [line.strip() for line in lines]
len(stop_words)
3893
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, HashingVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.linear_model import Ridge, Lasso, LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
def tokenizer(text):
    return text.split(" ")

corpus = train_df["cleaned_text"].values
corpus = [string if string is not np.nan else "" for string in corpus]
cvt = CountVectorizer(tokenizer=tokenizer, ngram_range=(1, 2))
x_cvt = cvt.fit_transform(corpus)
lda = LatentDirichletAllocation(n_components=32, doc_topic_prior=None, topic_word_prior=None, learning_method='batch', 
                                learning_decay=0.7, learning_offset=50.0, max_iter=10, batch_size=128, evaluate_every=-1, 
                                total_samples=1000000.0, perp_tol=0.1, mean_change_tol=0.001, max_doc_update_iter=100, 
                                n_jobs=None, verbose=0, random_state=402)
docres = lda.fit_transform(x_cvt)
docres.shape
(10000, 32)
y = train_df["label"].values
X_train, X_val, y_train, y_val = train_test_split(docres, y, test_size=0.1)

clf = Ridge(alpha=500.)
clf.fit(X_train, y_train)

print("train score: ")
y_pred = clf.predict(X_train)
print(roc_auc_score(y_train, y_pred), accuracy_score(y_train, y_pred>0.5))
print()
print("valid score: ")
y_pred = clf.predict(X_val)
print(roc_auc_score(y_val, y_pred), accuracy_score(y_val, y_pred>0.5))
train score: 
0.5984059229289742 0.5741111111111111

valid score: 
0.5797141495568878 0.57

gensim

corpus = [string.split(" ") for string in corpus]
from gensim import corpora
dictionary = corpora.Dictionary(corpus)
dictionary.save('qzone.dict')
dictionary.filter_extremes(no_below=20, no_above=0.5)
dictionary.compactify()
corpus = [dictionary.doc2bow(s) for s in corpus]
corpora.MmCorpus.serialize('corpus_bow.mm', corpus)  # 存储语料库
from gensim.models import LdaModel

num_topics = 100
chunksize = 2000
passes = 20
iterations = 400
eval_every = None 

temp = dictionary[0]
id2word = dictionary.id2token

model = LdaModel(
    corpus=corpus,
    id2word=id2word,
    chunksize=chunksize,
    alpha='auto',
    eta='auto',
    iterations=iterations,
    num_topics=num_topics,
    passes=passes,
    eval_every=eval_every
)

model.save('qzone.model')
top_topics = model.top_topics(corpus)
avg_topic_coherence = sum([t[1] for t in top_topics]) / num_topics
print('Average topic coherence: %.4f.' % avg_topic_coherence)
Average topic coherence: -5.7200.
len(top_topics), len(corpus)
(100, 10000)

LTP特征提取

import sys
!ls ../package/

sys.path.insert(0, "../package/")

from ltp import LTP
nlp = LTP(path="base")
ltp


file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
file /root/.cache/torch/ltp/8909177e47aa4daf900c569b86053ac68838d09da28c7bbeb42b8efcb08f56aa-edb9303f86310d4bcfd1ac0fa20a744c9a7e13ee515fe3cf88ad31921ed616b2-extracted/config.json not found
seg, hidden = nlp.seg(["他叫汤姆去拿外衣。"])
pos = nlp.pos(hidden)
ner = nlp.ner(hidden)
srl = nlp.srl(hidden)
dep = nlp.dep(hidden)
sdp = nlp.sdp(hidden)

对于LTP提取的特征,可以参考LTP的文档

  • 静态词向量
  • 动态词向量

python 包、模块的书写 以及 __all__ 变量的用法

一、模块

相信使用过Python编写代码的同学,会经常在文件头看到这样的import …,是的,这就是导入模块的语句,而每一个后缀名为.py的文件都是一个模块。

import jieba
import os 

1. 什么是模块?

  逻辑上来说模块是一组功能的组合;实质上一个模块就是一个包含了python定义和声明的文件,文件名就是模块名字加上.py的后缀。

import加载的模块分为四个通用类别:

a. 使用python编写的代码(.py文件);
b. 已被编译为共享库或DLL的C或C++扩展;
c. 包好一组模块的包
d. 使用C编写并链接到python解释器的内置模块;

如何使用模块?
  想要使用模块,必须先要将模块加载进来,可以通过关键字 import 或 from进行加载;需要注意的是模块和当前文件在不同的命名空间中。

2. 模块的构成

  模块可以包含可执行的语句和函数的定义,这些语句的目的是初始化模块,它们只在模块名第一次遇到导入import语句时才执行(import语句是可以在程序中的任意位置使用的,且针对同一个模块很import多次,为了防止你重复导入,python的优化手段是:第一次导入后就将模块名加载到内存了,后续的import语句仅是对已经加载大内存中的模块对象增加了一次引用,不会重新执行模块内的语句

二、模块的导入

1、导入整个模块

  比如我们有一个myModule的文件夹,里面有一个first.py文件,文件中的内容如下

a = 1
def myfun(s):
    print(s + 1)

  在myModule的文件夹下打开终端/cmd,输入python进入命令行交互模式
写完模块导入的语句之后,接着就可以调用该模块下的函数了。调用方式为

>>> import first
>>> a
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NameError: name 'a' is not defined
>>> first.a
1
>>> first.myfun(2)
3

在这里插入图片描述
2、导入特定的函数/变量

  所以说first.py文件就是一个模块,可以用import导入,里面变量和方法都要用first.前缀来引用,如果想不使用这个前缀或是我们只是想要使用模块中的某个函数,就可以只导入该变量或函数。导入方式为:from module_name import function_name。
  如果导入的是变量,就可以直接输入变量名来获得变量的值;如果直接导入的是函数,可以直接使用function_name() 的方式调用函数,无需在函数名前面加上模块名。

# 导入变量
>>> from first import a
>>> a
1
# 导入函数
>>> from first import myfun
>>> myfun(3)
4
# 一次导入多个变量
>>> from first import a,myfun
>>> a
1
>>> myfun(5)
6
# 导入模块中全部变量
>>> from first import *
>>> a
1
>>> myfun(5)
6
>>>

3、使用as给模块指定别名

  可以在后面使用as给函数指定别名。句式如:import module_name as new_name,

>>> import first as f
>>> f.a
1
>>> f.myfun(6)
7

在上述导入函数的基础上,可以在后面用as语句给导入的函数指定别名。句式如:from module_name import function_name as new_function。

>>> from first import myfun as add
>>> add(8)
9

三、包、库

模块(module) 其实就是py文件,里面定义了一些函数、类、变量等。
包(package) 是多个模块的聚合体形成的文件夹,里面可以是多个py文件,也可以嵌套文件夹。
是参考其他编程语言的说法,是指完成一定功能的代码集合,在python中的形式就是模块和包。

一个包的架构:

sound/                          Top-level package
      __init__.py               Initialize the sound package
      formats/                  Subpackage for file format conversions
              __init__.py
              wavread.py
              wavwrite.py
              aiffread.py
              aiffwrite.py
              auread.py
              auwrite.py
              ...
      effects/                  Subpackage for sound effects
              __init__.py
              echo.py
              surround.py
              reverse.py
              ...
      filters/                  Subpackage for filters
              __init__.py
              equalizer.py
              vocoder.py
              karaoke.py
              ...

Python 只把含 __init__.py 文件的目录当成包。这样可以防止以 string 等通用名称命名的目录,无意中屏蔽出现在后方模块搜索路径中的有效模块。 最简情况下,__init__.py 只是一个空文件,但该文件也可以执行包的初始化代码,或设置 __all__ 变量,详见下文。

四、包的导入

导入包的本质:导入一个包就是执行包下的__init__.py文件

只要一个文件夹下面有个__init__.py 文件,那么这个文件夹就可以看做是一个包

包导入的过程和模块的基本一致,只是导入包的时候会执行此包目录下的 init.py 而不是模块里面的语句了。另外,如果只是单纯的导入包,而包的 init.py 中又没有明确的其他初始化操作,那么此包下面的模块是不会自动导入的。

另外需要注意两点

  1. __ init__ .py文件编写时,如果要在__init__.py中导入其他模块中的变量,即使__ init__.py文件和abcd.py文件在同一个文件夹下,也不能from abcd import b,要从abcd文件从哪里来开始写,即从包的名称开始,from folder.abcd import b。
  2. folder文件夹里的嵌套文件夹内不需要新建__init__.py文件即可像模块一样调用,但是一般还是要新建这个文件,可以方便地导入常用变量。
  3. init.py文件其实是一个特殊的文件,它相当于名为folder模块,即如果使用import folder则可以调用在__init__.py文件文件中定义的变量。

五、__ all __

使用 from sound.effects import * 时会发生什么?理想情况下,该语句在文件系统查找并导入包的所有子模块。这项操作花费的时间较长,并且导入子模块可能会产生不必要的副作用,这种副作用只有在显式导入子模块时才会发生。

唯一的解决方案是提供包的显式索引。import 语句使用如下惯例:如果包的 __init__.py 代码定义了列表 __all__,运行 from package import * 时,它就是用于导入的模块名列表。发布包的新版本时,包的作者应更新此列表。如果包的作者认为没有必要在包中执行导入 * 操作,也可以不提供此列表。例如,sound/effects/__init__.py 文件包含以下代码:

__all__ = ["echo", "surround", "reverse"]

这将意味着将 from sound.effects import * 导入 sound.effects 包的三个命名的子模块。

如果没有定义 __all__from sound.effects import * 语句 不会 把包 sound.effects 中所有子模块都导入到当前命名空间;该语句只确保导入包 sound.effects (可能还会运行 __init__.py 中的初始化代码),然后,再导入包中定义的名称。这些名称包括 __init__.py 中定义的任何名称(以及显式加载的子模块),还包括之前 import 语句显式加载的包里的子模块。

变量__all__的好处:只会导出all中的子模块,可以有效地避免命名空间的污染,并加速模块的导入

一、模块公开接口的一种约定
__all__可以在模块级别暴露接口,形式如下:
__all__ = [“foo”, “bar”]
Python 没有原生的可见性控制,其可见性的维护是靠一套需要大家自觉遵守的”约定“,比如,下划线开头的变量对外部不可见。
__all__ 是针对模块公开接口的一种约定,以提供了”白名单“的形式暴露接口。如果定义了__all__,其他文件中使用from xxx import *导入该文件时,只会导入 __all__ 列出的成员,可以其他成员都被排除在外。
如,test1.py,test2.py,test3.py三个文件:
test1.py
#__all__ = [‘func’]
def func():
pass

test2.py
import test1

__all__ = [‘func2’, ‘test1’]
def func2():
pass

def func22():
pass

test3.py
from test2 import *

func2() #能正常引用
test1.func() #能正常引用
func22() #不能正常引用

二、控制 from xxx import * 的行为
python不提倡用 from xxx import * 这种写法。如果一个模块 xxx 没有定义 __all__,执行 from spam import * 时会将 xxx 中所有非下划线开头的成员(包括该模块import的其他模块成员)都会导入当前命名空间,这样就可能弄脏当前的命名空间。显式声明了 __all__,import * 就只会导入 __all__ 列出的成员,如果 __all__ 定义有误,还会明确地抛出异常,方便检查错误。

三、为 lint 等代码检查工具提供辅助
编写库时,经常会在 __init__.py 中暴露整个包的 API,而这些 API 的实现可能是在包的其他模块中。如果仅仅这样写:from xxx import a, b,一些代码检查工具,如 pyflakes 会报错,认为变量 a和 b import 了但没被使用。一个可行的方法是把这个警告压掉:from xxx import a, b # noqa (No Q/A,即无质量保证),但更好的方法是显式定义 __all__,这样代码检查工具就会理解,从而不再报 unused variables 的警告。

四、定义 all 需要注意的地方

  • __all__ 的形式都是 list类型。如果写成其他类型, pyflakes 等 lint 工具可能无法识别。
  • 不能动态生成 __all__,如使用列表解析式。__all__ 的作用是定义公开接口,需要以字面量的形式显式写出来。
  • 即使定义了 __all__, 也不应该在非临时代码中使用 from xxx import * 语法,或用编程工具模拟 Ruby 的自动 import。Python 不像 Ruby,没有 Module 这类成员,模块就是命名空间隔离的执行者。如果打破了这一层,引入诸多动态因素,生产环境中跑的代码就可能充满不确定性,调试也会变得困难。
  • 按照 PEP8 建议的风格,__all__ 应该写在所有 import 语句下面,函数、常量等成员定义的上面。
  • 如果一个模块需要暴露的接口改动频繁,__all__ 可以这样定义:

__all__ = [
“foo”,
“bar”,
“egg”,
]
这样修改一个暴露的接口只修改一行,方便版本控制的时候看 diff。最后多出的逗号在 Python 中是允许的,符合 PEP8 风格。

由上面的输出结果,我们可以知道import *只会导入__all__中指定的变量,无论是否以下划线开头。这样限制可以防止import *命令导入太多变量污染命名空间,过滤掉一些中间变量如b

五、模块导入的绝对引用与相对引用

python中的import分为绝对引用和相对引用两种。它们之间的差异在于,引用模块时,定位被引用模块位置 的方式不同。

绝对引用是通过.的连接,指定出最高级文件(夹),到目标文件的绝对路径。我们上面的所有用法都属于绝对引用。

而相对引用是指定待引用模块与当前文件的相对位置,.表示上一级文件

  • 绝对引用:from folder.abcd import myclass
  • 相对引用:from .abcd import myclass

在实际使用中,无论是绝对导入还是相对导入都要注意,如何导入与被调用位置有关。

Python装饰器:python中的@符号的作用 以及 torch中经常出现的 @torch.no_grad()

@符号是装饰器(修饰符)的语法糖,在定义函数的时候使用,避免再一次赋值操作

装饰器(Decorators)是 Python 的一个重要部分。简单地说:他们是修改其他函数的功能的函数。他们有助于让我们的代码更简短,也更Pythonic(Python范儿)。大多数初学者不知道在哪儿使用它们,所以我将要分享下,哪些区域里装饰器可以让你的代码更简洁。 首先,让我们讨论下如何写你自己的装饰器。

‘@’符号用作函数修饰符是python2.4新增加的功能,修饰符必须出现在函数定义前一行,不允许和函数定义在同一行。也就是说@A def f(): 是非法的。只可以在模块或类定义层内对函数进行修饰,不允许修饰一个类。一个修饰符就是一个函数,它将被修饰的函数做为参数,并返回修饰后的同名函数或其它可调用的东西。

实例(1):

def spamrun(fn):
   def sayspam(*args):
       print("spam,spam,spam")
   return sayspam
@spamrun
def useful(a,b):
   print (a**2+b**2)

执行: useful(3,4)

返回:spam,spam,spam

def addspam(fn):
   def new(*args):
       print "spam,spam,spam"
       return fn(*args)
   return new

@addspam
def useful(a,b):
   print a**2+b**2

执行: useful(4,3)

结果:

spam,spam,spam

25

@torch.no_grad()

@torch.no_grad()
def eval():
	...

@torch.no_grad()后面的函数的数据不需要计算梯度,也不会进行反向传播

Python装饰器:

装饰器本质上是一个Python函数,它可以让其他函数在不需要做任何代码变动的前提下增加额外功能,装饰器的返回值也是一个函数对象。它经常用于有切面需求的场景,比如:插入日志、性能测试、事务处理、缓存、权限校验等场景。装饰器是解决这类问题的绝佳设计,有了装饰器,我们就可以抽离出大量与函数功能本身无关的雷同代码并继续重用。概括的讲,装饰器的作用就是为已经存在的对象添加额外的功能。

先来看一个简单例子:

def foo():
    print('i am foo')

现在有一个新的需求,希望可以记录下函数的执行日志,于是在代码中添加日志代码:

def foo():
    print('i am foo')
    logging.info("foo is running")

bar()、bar2()也有类似的需求,怎么做?再写一个logging在bar函数里?这样就造成大量雷同的代码,为了减少重复写代码,我们可以这样做,重新定义一个函数:专门处理日志 ,日志处理完之后再执行真正的业务代码

def use_logging(func):
    logging.warn("%s is running" % func.__name__)
    func()

def bar():
    print('i am bar')

use_logging(bar)

逻辑上不难理解, 但是这样的话,我们每次都要将一个函数作为参数传递给use_logging函数。而且这种方式已经破坏了原有的代码逻辑结构,之前执行业务逻辑时,执行运行bar(),但是现在不得不改成use_logging(bar)。那么有没有更好的方式的呢?当然有,答案就是装饰器。

简单装饰器

def use_logging(func):

    def wrapper(*args, **kwargs):
        logging.warn("%s is running" % func.__name__)
        return func(*args, **kwargs)
    return wrapper

def bar():
    print('i am bar')

bar = use_logging(bar)
bar()

函数use_logging就是装饰器,它把执行真正业务方法的func包裹在函数里面,看起来像bar被use_logging装饰了。在这个例子中,函数进入和退出时 ,被称为一个横切面(Aspect),这种编程方式被称为面向切面的编程(Aspect-Oriented Programming)。

@符号是装饰器的语法糖,在定义函数的时候使用,避免再一次赋值操作

方法一:不用语法糖@符号​​​​​​​

# 装饰器不传入参数时
f = decorator(函数名)

# 装饰器传入参数时
f = (decorator(参数))(函数名)


方法二:采用语法糖@符号​​​​​​​

# 已定义的装饰器
@decorator 
def f():  
    pass

# 执行被装饰过的函数 
f()
def use_logging(func):

    def wrapper(*args, **kwargs):
        logging.warn("%s is running" % func.__name__)
        return func(*args)
    return wrapper

@use_logging
def foo():
    print("i am foo")

@use_logging
def bar():
    print("i am bar")

bar()

如上所示,这样我们就可以省去bar = use_logging(bar)这一句了,直接调用bar()即可得到想要的结果。如果我们有其他的类似函数,我们可以继续调用装饰器来修饰函数,而不用重复修改函数或者增加新的封装。这样,我们就提高了程序的可重复利用性,并增加了程序的可读性。

装饰器在Python使用如此方便都要归因于Python的函数能像普通的对象一样能作为参数传递给其他函数,可以被赋值给其他变量,可以作为返回值,可以被定义在另外一个函数内。

带参数的装饰器

装饰器还有更大的灵活性,例如带参数的装饰器:在上面的装饰器调用中,比如@use_logging,该装饰器唯一的参数就是执行业务的函数。装饰器的语法允许我们在调用时,提供其它参数,比如@decorator(a)。这样,就为装饰器的编写和使用提供了更大的灵活性。

def use_logging(level):
    def decorator(func):
        def wrapper(*args, **kwargs):
            if level == "warn":
                logging.warn("%s is running" % func.__name__)
            return func(*args)
        return wrapper

    return decorator

@use_logging(level="warn")
def foo(name='foo'):
    print("i am %s" % name)

foo()

上面的use_logging是允许带参数的装饰器。它实际上是对原有装饰器的一个函数封装,并返回一个装饰器。我们可以将它理解为一个含有参数的闭包。当我 们使用@use_logging(level=”warn”)调用的时候,Python能够发现这一层的封装,并把参数传递到装饰器的环境中。

类装饰器

再来看看类装饰器,相比函数装饰器,类装饰器具有灵活度大、高内聚、封装性等优点。使用类装饰器还可以依靠类内部的__call__方法,当使用 @ 形式将装饰器附加到函数上时,就会调用此方法。

__call__方法 : 在生成一个类的实例时,自动自行一次call方法

当执行Foo时候生成一个实例,就会自动调用__call__方法

class Foo(object):
    def __init__(self, func):
    self._func = func

def __call__(self):
    print ('class decorator runing')
    self._func()
    print ('class decorator ending')

@Foo
def bar():
    print ('bar')

bar()

functools.wraps

使用装饰器极大地复用了代码,但是他有一个缺点就是原函数的元信息不见了,比如函数的docstring、__name__、参数列表,先看例子:

装饰器

def logged(func):
    def with_logging(*args, **kwargs):
        print func.__name__ + " was called"
        return func(*args, **kwargs)
    return with_logging

函数

@logged
def f(x):
   """does some math"""
   return x + x * x

该函数完成等价于:

def f(x):
    """does some math"""
    return x + x * x
f = logged(f)

不难发现,函数f被with_logging取代了,当然它的docstring,__name__就是变成了with_logging函数的信息了。

print f.__name__    # prints 'with_logging'
print f.__doc__     # prints None

这个问题就比较严重的,好在我们有functools.wraps,wraps本身也是一个装饰器,它能把原函数的元信息拷贝到装饰器函数中,这使得装饰器函数也有和原函数一样的元信息了。

from functools import wraps
def logged(func):
    @wraps(func)
    def with_logging(*args, **kwargs):
        print func.__name__ + " was called"
        return func(*args, **kwargs)
    return with_logging

@logged
def f(x):
    """does some math"""
    return x + x * x

print f.__name__  # prints 'f'
print f.__doc__   # prints 'does some math'

内置装饰器

@staticmathod、@classmethod、@property

@property

把类内方法当成属性来使用,必须要有返回值,相当于getter;

假如没有定义 @func.setter 修饰方法的话,就是只读属性

class Car:

    def __init__(self, name, price):
        self._name = name
        self._price = price    
     
    @property
    def car_name(self):
        return self._name
        
     # car_name可以读写的属性   
     @car_name.setter
     def car_name(self, value):
         self._name = value
         
     # car_price是只读属性 
     @property
     def car_price(self):
         return str(self._price) + '万'
         
benz = Car('benz', 30)

print(benz.car_name)   # benz
benz.car_name = "baojun"
print(benz.car_name)   # baojun
print(benz.car_price)  # 30万

@staticmethod

静态方法,不需要表示自身对象的self和自身类的cls参数,就跟使用函数一样。

静态方法的使用场景:

如果在方法中不需要访问任何实例方法和属性,纯粹地通过传入参数并返回数据的功能性方法,那么它就适合用静态方法来定义,它节省了实例化对象的开销成本,往往这种方法放在类外面的模块层作为一个函数存在也是没问题的,而放在类中,仅为这个类服务。

@classmethod

类方法,不需要self参数,但第一个参数需要是表示自身类的cls参数。

类方法的使用场景有:

作为工厂方法创建实例对象,例如内置模块 datetime.date 类中就有大量使用类方法作为工厂方法,以此来创建date对象。如果希望在方法里面调用静态类,那么把方法定义成类方法是合适的,因为要是定义成静态方法,那么你就要显示地引用类A,这对继承来说可不是一件好事情。

例子

class Demo(object):

    text = "三种方法的比较"
    
    def instance_method(self):
        print("调用实例方法")

    @classmethod
    def class_method(cls):
        print("调用类方法")
        print("在类方法中 访问类属性 text: {}".format(cls.text))
        print("在类方法中 调用实例方法 instance_method: {}".format(cls().instance_method()))

    @staticmethod
    def static_method():
        print("调用静态方法")
        print("在静态方法中 访问类属性 text: {}".format(Demo.text))
        print("在静态方法中 调用实例方法 instance_method: {}".format(Demo().instance_method()))

if __name__ == "__main__":
    # 实例化对象
    d = Demo()
    
    # 对象可以访问 实例方法、类方法、静态方法
    # 通过对象访问text属性
    print(d.text)
    
    # 通过对象调用实例方法
    d.instance_method()
    
    # 通过对象调用类方法
    d.class_method()
    
    # 通过对象调用静态方法
    d.static_method()
    
    # 类可以访问类方法、静态方法
    # 通过类访问text属性
    print(Demo.text)
    
    # 通过类调用类方法
    Demo.class_method()
    
    # 通过类调用静态方法
    Demo.static_method()

@staticmethod 和 @classmethod 的 区别 和 使用场景

在上述例子中,我们可以看出,

区别

在定义静态类方法和类方法时,@staticmethod 装饰的静态方法里面,想要访问类属性或调用实例方法,必须需要把类名写上;

@classmethod装饰的类方法里面,会传一个cls参数,代表本类,这样就能够避免手写类名的硬编码。

在调用静态方法和类方法时,实际上写法都差不多,一般都是通过 类名.静态方法() 或 类名.类方法()。也可以用实例对象调用类方法和静态方法。 对象可以访问 实例方法、类方法、静态方法 , 类可以访问类方法、静态方法

也可以用实例化对象去调用静态方法和类方法但为了和实例方法区分,最好还是用类去调用静态方法和类方法。

使用场景

所以,在定义类的时候,

假如不需要用到与类相关的属性或方法时,就用静态方法@staticmethod

假如需要用到与类相关的属性或方法,然后又想表明这个方法是整个类通用的,而不是对象特异的,就可以使用类方法@classmethod

装饰器的顺序

@a
@b
@c
def f ():

等效于

f = a(b(c(f)))

__call__方法和可调用对象

在看神经网络代码的时候,类定义中总是会出现__call__方法,它是Python类体中可以定义的一个特殊方法,定义了该方法的对象称为可调用对象,即该对象可以像函数一样被调用。

定义一个可以求解自由落体下降高度的对象

类的实例:如果类定义了 __call__ 方法,那么它的实例可以作为函数调用。每执行一次实例,__call__函数就执行一遍:

可以理解为 实例() ==实例.__call__()

class GDistance:
    def __init__(self, g):
        self.g = g
    def __call__(self, t):
        # 自由落体下降距离 s=g*t^2
        return (self.g * t ** 2)/2

调用对象

e_gdist = GDistance(9.8)
for t in range(11):
    print("%d 秒 下降%.2f 米" %(t, e_gdist(t)))

输出结果

0 秒 下降0.00 米
1 秒 下降4.90 米
2 秒 下降19.60 米
3 秒 下降44.10 米
4 秒 下降78.40 米
5 秒 下降122.50 米
6 秒 下降176.40 米
7 秒 下降240.10 米
8 秒 下降313.60 米
9 秒 下降396.90 米
10 秒 下降490.00 米

总之:如果一个类中定义了 __call__ 方法,那么该类的实例可以作为函数调用,并执行__call__方法所定义的内容。

  1. class CLanguage:
  2. # 定义__call__方法
  3. def __call__(self,name,add):
  4. print(“调用__call__()方法”,name,add)
  5. clangs = CLanguage()
  6. clangs(“C语言中文网”,”http://c.biancheng.net”)

程序执行结果为:

调用__call__()方法 C语言中文网 http://c.biancheng.net

可以看到,通过在 CLanguage 类中实现 __call__() 方法,使的 clangs 实例对象变为了可调用对象。

Python 中,凡是可以将 () 直接应用到自身并执行,都称为可调用对象。可调用对象包括自定义的函数、Python 内置函数以及本节所讲的类实例对象。

对于可调用对象,实际上“名称()”可以理解为是“名称.__call__()”的简写。仍以上面程序中定义的 clangs 实例对象为例,其最后一行代码还可以改写为如下形式:

clangs.__call__("C语言中文网","http://c.biancheng.net")