ONNX—-模型部署教程(1)

转自:mmdeploy

模型转换工具 https://convertmodel.com/

官网1:https://pytorch.org/docs/stable/onnx.html#functions

官网2:https://onnxruntime.ai/docs/get-started/

前言

OpenMMLab 的算法如何部署?是很多社区用户的困惑。而模型部署工具箱 MMDeploy 的开源,强势打通了从算法模型到应用程序这 “最后一公里”!

今天我们将开启模型部署入门系列教程,在模型部署开源库 MMDeploy 的辅助下,介绍以下内容:

  • 中间表示 ONNX 的定义标准
  • PyTorch 模型转换到 ONNX 模型的方法
  • 推理引擎 ONNX Runtime、TensorRT 的使用方法
  • 部署流水线 PyTorch – ONNX – ONNX Runtime/TensorRT 的示例及常见部署问题的解决方法
  • MMDeploy C/C++ 推理 SDK

希望通过本系列教程,带领大家学会如何把自己的 PyTorch 模型部署到 ONNX Runtime/TensorRT 上,并学会如何把 OpenMMLab 开源体系中各个计算机视觉任务的模型用 MMDeploy 部署到各个推理引擎上。

初识模型部署

在软件工程中,部署指把开发完毕的软件投入使用的过程,包括环境配置、软件安装等步骤。类似地,对于深度学习模型来说,模型部署指让训练好的模型在特定环境中运行的过程。相比于软件部署,模型部署会面临更多的难题:

1)运行模型所需的环境难以配置。深度学习模型通常是由一些框架编写,比如 PyTorch、TensorFlow。由于框架规模、依赖环境的限制,这些框架不适合在手机、开发板等生产环境中安装。

2)深度学习模型的结构通常比较庞大,需要大量的算力才能满足实时运行的需求。模型的运行效率需要优化。

因为这些难题的存在,模型部署不能靠简单的环境配置与安装完成。经过工业界和学术界数年的探索,模型部署有了一条流行的流水线:

为了让模型最终能够部署到某一环境上,开发者们可以使用任意一种深度学习框架来定义网络结构,并通过训练确定网络中的参数。之后,模型的结构和参数会被转换成一种只描述网络结构的中间表示,一些针对网络结构的优化会在中间表示上进行。最后,用面向硬件的高性能编程框架(如 CUDA,OpenCL)编写,能高效执行深度学习网络中算子的推理引擎会把中间表示转换成特定的文件格式,并在对应硬件平台上高效运行模型。

这一条流水线解决了模型部署中的两大问题:使用对接深度学习框架和推理引擎的中间表示,开发者不必担心如何在新环境中运行各个复杂的框架;通过中间表示的网络结构优化和推理引擎对运算的底层优化,模型的运算效率大幅提升。

中间表示 – ONNX

在介绍 ONNX 之前,我们先从本质上来认识一下神经网络的结构。神经网络实际上只是描述了数据计算的过程,其结构可以用计算图表示。比如 a+b 可以用下面的计算图来表示:

为了加速计算,一些框架会使用对神经网络“先编译,后执行”的静态图来描述网络。静态图的缺点是难以描述控制流(比如 if-else 分支语句和 for 循环语句),直接对其引入控制语句会导致产生不同的计算图。比如循环执行 n 次 a=a+b,对于不同的 n,会生成不同的计算图:

ONNX (Open Neural Network Exchange)是 Facebook 和微软在2017年共同发布的,用于标准描述计算图的一种格式。目前,在数家机构的共同维护下,ONNX 已经对接了多种深度学习框架和多种推理引擎。因此,ONNX 被当成了深度学习框架到推理引擎的桥梁,就像编译器的中间语言一样。由于各框架兼容性不一,我们通常只用 ONNX 表示更容易部署的静态图。

创建 PyTorch 模型

让我们用 PyTorch 实现一个超分辨率模型,并把模型部署到 ONNX Runtime 这个推理引擎上。

# 安装 ONNX Runtime, ONNX, OpenCV 
pip install onnxruntime onnx opencv-python

在一切都配置完毕后,用下面的代码来创建一个经典的超分辨率模型 SRCNN。

import os 
 
import cv2 
import numpy as np 
import requests 
import torch 
import torch.onnx 
from torch import nn 
 
class SuperResolutionNet(nn.Module): 
    def __init__(self, upscale_factor): 
        super().__init__() 
        self.upscale_factor = upscale_factor 
        self.img_upsampler = nn.Upsample( 
            scale_factor=self.upscale_factor, 
            mode='bicubic', 
            align_corners=False) 
 
        self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4) 
        self.conv2 = nn.Conv2d(64,32,kernel_size=1,padding=0) 
        self.conv3 = nn.Conv2d(32,3,kernel_size=5,padding=2) 
 
        self.relu = nn.ReLU() 
 
    def forward(self, x): 
        x = self.img_upsampler(x) 
        out = self.relu(self.conv1(x)) 
        out = self.relu(self.conv2(out)) 
        out = self.conv3(out) 
        return out 
 
# Download checkpoint and test image 
urls = ['https://download.openmmlab.com/mmediting/restorers/srcnn/srcnn_x4k915_1x16_1000k_div2k_20200608-4186f232.pth', 
    'https://raw.githubusercontent.com/open-mmlab/mmediting/master/tests/data/face/000001.png'] 
names = ['srcnn.pth', 'face.png'] 
for url, name in zip(urls, names): 
    if not os.path.exists(name): 
        open(name, 'wb').write(requests.get(url).content) 
 
def init_torch_model(): 
    torch_model = SuperResolutionNet(upscale_factor=3) 
 
    state_dict = torch.load('srcnn.pth')['state_dict'] 
 
    # Adapt the checkpoint 
    for old_key in list(state_dict.keys()): 
        new_key = '.'.join(old_key.split('.')[1:]) 
        state_dict[new_key] = state_dict.pop(old_key) 
 
    torch_model.load_state_dict(state_dict) 
    torch_model.eval() 
    return torch_model 
 
model = init_torch_model() 
input_img = cv2.imread('face.png').astype(np.float32) 
 
# HWC to NCHW 
input_img = np.transpose(input_img, [2, 0, 1]) 
input_img = np.expand_dims(input_img, 0) 
 
# Inference 
torch_output = model(torch.from_numpy(input_img)).detach().numpy() 
 
# NCHW to HWC 
torch_output = np.squeeze(torch_output, 0) 
torch_output = np.clip(torch_output, 0, 255) 
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8) 
 
# Show image 
cv2.imwrite("face_torch.png", torch_output)

SRCNN 先把图像上采样到对应分辨率,再用 3 个卷积层处理图像。为了方便起见,我们跳过训练网络的步骤,直接下载模型权重(由于 MMEditing 中 SRCNN 的权重结构和我们定义的模型不太一样,我们修改了权重字典的 key 来适配我们定义的模型),同时下载好输入图片。为了让模型输出成正确的图片格式,我们把模型的输出转换成 HWC 格式,并保证每一通道的颜色值都在 0~255 之间。如果脚本正常运行的话,一幅超分辨率的人脸照片会保存在 “face_torch.png” 中。

在 PyTorch 模型测试正确后,我们来正式开始部署这个模型。我们下一步的任务是把 PyTorch 模型转换成用中间表示 ONNX 描述的模型。

让我们用下面的代码来把 PyTorch 的模型转换成 ONNX 格式的模型:

x = torch.randn(1, 3, 256, 256) 
 
with torch.no_grad(): 
    torch.onnx.export( 
        model, 
        x, 
        "srcnn.onnx", 
        opset_version=11, 
        input_names=['input'], 
        output_names=['output'])

其中,torch.onnx.export 是 PyTorch 自带的把模型转换成 ONNX 格式的函数。让我们先看一下前三个必选参数:前三个参数分别是要转换的模型、模型的任意一组输入、导出的 ONNX 文件的文件名。转换模型时,需要原模型和输出文件名是很容易理解的,但为什么需要为模型提供一组输入呢?这就涉及到 ONNX 转换的原理了。从 PyTorch 的模型到 ONNX 的模型,本质上是一种语言上的翻译。直觉上的想法是像编译器一样彻底解析原模型的代码,记录所有控制流。但前面也讲到,我们通常只用 ONNX 记录不考虑控制流的静态图。因此,PyTorch 提供了一种叫做追踪(trace)的模型转换方法:给定一组输入,再实际执行一遍模型,即把这组输入对应的计算图记录下来,保存为 ONNX 格式。export 函数用的就是追踪导出方法,需要给任意一组输入,让模型跑起来。我们的测试图片是三通道,256×256大小的,这里也构造一个同样形状的随机张量。

剩下的参数中,opset_version 表示 ONNX 算子集的版本。深度学习的发展会不断诞生新算子,为了支持这些新增的算子,ONNX会经常发布新的算子集,目前已经更新15个版本。我们令 opset_version = 11,即使用第11个 ONNX 算子集,是因为 SRCNN 中的 bicubic (双三次插值)在 opset11 中才得到支持。剩下的两个参数 input_names, output_names 是输入、输出 tensor 的名称,我们稍后会用到这些名称。

如果上述代码运行成功,目录下会新增一个”srcnn.onnx”的 ONNX 模型文件。我们可以用下面的脚本来验证一下模型文件是否正确。

import onnx 
 
onnx_model = onnx.load("srcnn.onnx") 
try: 
    onnx.checker.check_model(onnx_model) 
except Exception: 
    print("Model incorrect") 
else: 
    print("Model correct")

其中,onnx.load 函数用于读取一个 ONNX 模型。onnx.checker.check_model 用于检查模型格式是否正确,如果有错误的话该函数会直接报错。我们的模型是正确的,控制台中应该会打印出”Model correct”。

接下来,让我们来看一看 ONNX 模型具体的结构是怎么样的。我们可以使用 Netron (开源的模型可视化工具)来可视化 ONNX 模型。把 srcnn.onnx 文件从本地的文件系统拖入网站,即可看到如下的可视化结果:

点击 input 或者 output,可以查看 ONNX 模型的基本信息,包括模型的版本信息,以及模型输入、输出的名称和数据类型。

点击某一个算子节点,可以看到算子的具体信息。比如点击第一个 Conv 可以看到:

每个算子记录了算子属性、图结构、权重三类信息。

  • 算子属性信息即图中 attributes 里的信息,对于卷积来说,算子属性包括了卷积核大小(kernel_shape)、卷积步长(strides)等内容。这些算子属性最终会用来生成一个具体的算子。
  • 图结构信息指算子节点在计算图中的名称、邻边的信息。对于图中的卷积来说,该算子节点叫做 Conv_2,输入数据叫做 11,输出数据叫做 12。根据每个算子节点的图结构信息,就能完整地复原出网络的计算图。
  • 权重信息指的是网络经过训练后,算子存储的权重信息。对于卷积来说,权重信息包括卷积核的权重值和卷积后的偏差值。点击图中 conv1.weight, conv1.bias 后面的加号即可看到权重信息的具体内容。

现在,我们有了 SRCNN 的 ONNX 模型。让我们看看最后该如何把这个模型运行起来。

推理引擎 -ONNX Runtime

ONNX Runtime 是由微软维护的一个跨平台机器学习推理加速器,也就是我们前面提到的”推理引擎“。ONNX Runtime 是直接对接 ONNX 的,即 ONNX Runtime 可以直接读取并运行 .onnx 文件, 而不需要再把 .onnx 格式的文件转换成其他格式的文件。也就是说,对于 PyTorch – ONNX – ONNX Runtime 这条部署流水线,只要在目标设备中得到 .onnx 文件,并在 ONNX Runtime 上运行模型,模型部署就算大功告成了。

通过刚刚的操作,我们把 PyTorch 编写的模型转换成了 ONNX 模型,并通过可视化检查了模型的正确性。最后,让我们用 ONNX Runtime 运行一下模型,完成模型部署的最后一步。

ONNX Runtime 提供了 Python 接口。接着刚才的脚本,我们可以添加如下代码运行模型:

import onnxruntime 
 
ort_session = onnxruntime.InferenceSession("srcnn.onnx") 
ort_inputs = {'input': input_img} 
ort_output = ort_session.run(['output'], ort_inputs)[0] 
 
ort_output = np.squeeze(ort_output, 0) 
ort_output = np.clip(ort_output, 0, 255) 
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8) 
cv2.imwrite("face_ort.png", ort_output)

这段代码中,除去后处理操作外,和 ONNX Runtime 相关的代码只有三行。让我们简单解析一下这三行代码。onnxruntime.InferenceSession用于获取一个 ONNX Runtime 推理器,其参数是用于推理的 ONNX 模型文件。推理器的 run 方法用于模型推理,其第一个参数为输出张量名的列表,第二个参数为输入值的字典。其中输入值字典的 key 为张量名,value 为 numpy 类型的张量值。输入输出张量的名称需要和torch.onnx.export 中设置的输入输出名对应。

如果代码正常运行的话,另一幅超分辨率照片会保存在”face_ort.png”中。这幅图片和刚刚得到的”face_torch.png”是一模一样的。这说明 ONNX Runtime 成功运行了 SRCNN 模型,模型部署完成了!以后有用户想实现超分辨率的操作,我们只需要提供一个 “srcnn.onnx” 文件,并帮助用户配置好 ONNX Runtime 的 Python 环境,用几行代码就可以运行模型了。或者还有更简便的方法,我们可以利用 ONNX Runtime 编译出一个可以直接执行模型的应用程序。我们只需要给用户提供 ONNX 模型文件,并让用户在应用程序选择要执行的 ONNX 模型文件名就可以运行模型了。

总结

  • 模型部署,指把训练好的模型在特定环境中运行的过程。模型部署要解决模型框架兼容性差和模型运行速度慢这两大问题。
  • 模型部署的常见流水线是“深度学习框架-中间表示-推理引擎”。其中比较常用的一个中间表示是 ONNX。
  • 深度学习模型实际上就是一个计算图。模型部署时通常把模型转换成静态的计算图,即没有控制流(分支语句、循环语句)的计算图。
  • PyTorch 框架自带对 ONNX 的支持,只需要构造一组随机的输入,并对模型调用 torch.onnx.export 即可完成 PyTorch 到 ONNX 的转换。
  • 推理引擎 ONNX Runtime 对 ONNX 模型有原生的支持。给定一个 .onnx 文件,只需要简单使用 ONNX Runtime 的 Python API 就可以完成模型推理。

关于模型部署

当我们千辛万苦完成了前面的数据获取、数据清洗、模型训练、模型评估等等步骤之后,终于等到“上线”啦。想到辛苦训练出来的模型要被调用还有点小激动呢,可是真当下手的时候就有点懵了:模型要怎么部署?部署在哪里?有什么限制或要求?

模型训练重点关注的是如何通过训练策略来得到一个性能更好的模型,其过程似乎包含着各种“玄学”,被戏称为“炼丹”。整个流程包含从训练样本的获取(包括数据采集与标注),模型结构的确定,损失函数和评价指标的确定,到模型参数的训练,这部分更多是业务方去承接相关工作。一旦“炼丹”完成(即训练得到了一个指标不错的模型),如何将这颗“丹药”赋能到实际业务中,充分发挥其能力,这就是部署方需要承接的工作。

目前来说,我还没有真正的接触工业界的模型应用,仅仅只是在学术界进行模型训练和探索。部署只是简单的将模型推理代码直接部署在服务器上,因此也没有考虑过模型在工业界的部署,对于工业应用来说,模型的表现和推理速度同等重要,因此,如何提高模型的推理速度成为模型能否落地的关键因素。

部署流程大致分为以下几个步骤:模型转换、模型量化压缩、模型打包封装 SDK。这里我们主要探讨模型转换。

模型转换主要用于模型在不同框架之间的流转,常用于训练和推理场景的连接。目前主流的框架都以 ONNX 或者 caffe 为模型的交换格式,另外,根据需要,还可以在中间插入计算图优化,对计算机进行推理加速(诸如常见的 CONV/BN 的算子融合)。

模型部署

在软件工程中,部署指把开发完毕的软件投入使用的过程,包括环境配置、软件安装等步骤。类似地,对于深度学习模型来说,模型部署指让训练好的模型在特定环境中运行的过程。相比于软件部署,模型部署会面临更多的难题:

1)运行模型所需的环境难以配置。深度学习模型通常是由一些框架编写,比如 PyTorch、TensorFlow。由于框架规模、依赖环境的限制,这些框架不适合在手机、开发板等生产环境中安装。

2)深度学习模型的结构通常比较庞大,需要大量的算力才能满足实时运行的需求。模型的运行效率需要优化。

因为这些难题的存在,模型部署不能靠简单的环境配置与安装完成。经过工业界和学术界数年的探索,模型部署有了一条流行的流水线:

为了让模型最终能够部署到某一环境上,开发者们可以使用任意一种深度学习框架来定义网络结构,并通过训练确定网络中的参数。之后,模型的结构和参数会被转换成一种只描述网络结构的中间表示,一些针对网络结构的优化会在中间表示上进行。最后,用面向硬件的高性能编程框架(如 CUDA,OpenCL)编写,能高效执行深度学习网络中算子的推理引擎会把中间表示转换成特定的文件格式,并在对应硬件平台上高效运行模型。

这一条流水线解决了模型部署中的两大问题:使用对接深度学习框架和推理引擎的中间表示,开发者不必担心如何在新环境中运行各个复杂的框架;通过中间表示的网络结构优化和推理引擎对运算的底层优化,模型的运算效率大幅提升。

英伟达显卡监控工具nvtop & 深度学习cuda+pytorch安装教程

背景

在用英伟达显卡做深度学习训练或推理时,我们常用nvidia-smi指令来查看显卡的使用情况,如图所示

这种方法可以看出每张显卡内存和GPU利用率的实时情况,但看不出历史数据和变化曲线,这个时候就需要用到nvtop了。

Nvtop代表NVidia TOP,这是用于NVIDIA GPU的任务监视器。它可以处理多个GPU,并以熟悉的方式打印有关它们的信息。如图所示,很直观的显示了每张显卡的内存、GPU利用率曲线。本文对该工具的安装使用进行介绍。

1 安装方法

在Ubuntu disco (19.04) / Debian buster (stable)系统中,可以直接使用apt安装

sudo apt install nvtop

如果是在旧的系统,如ubuntu16.04等,则需要通过源码安装,方法如下

# 安装依赖sudo apt install cmake libncurses5-dev libncursesw5-dev git # 下载源码git clone https://github.com/Syllo/nvtop.gitmkdir -p nvtop/build && cd nvtop/buildcmake .. # 如果报错"Could NOT find NVML (missing: NVML_INCLUDE_DIRS)"# 则执行下边的语句,否则跳过cmake .. -DNVML_RETRIEVE_HEADER_ONLINE=True # 编译makesudo make install 

2 使用方法

安装完之后,可以执行nvtop -h来查看使用方法,介绍的很详细了,如果现实全部信息,直接nvtop就可以现实出我们上边的结果

nvtop version 1.0.0Available options:  -d --delay        : Select the refresh rate (1 == 0.1s)  -v --version      : Print the version and exit  -s --gpu-select   : Column separated list of GPU IDs to monitor  -i --gpu-ignore   : Column separated list of GPU IDs to ignore  -p --no-plot      : Disable bar plot  -C --no-color     : No colors  -N --no-cache     : Always query the system for user names and command line information  -f --freedom-unit : Use fahrenheit  -E --encode-hide  : Set encode/decode auto hide time in seconds (default 30s, negative = always on screen)  -h --help         : Print help and exit

CUDA+PYTORCH安装教程:

环境安装

albumentations 数据增强工具

简介 & 安装

albumentations 是一个给予 OpenCV的快速训练数据增强库,拥有非常简单且强大的可以用于多种任务(分割、检测)的接口,易于定制且添加其他框架非常方便。支持所有常见的计算机视觉任务,例如分类,语义分割,实例分割,对象检测和姿势估计。

albumentations包是一种针对数据增强专门写的API,里面基本包含大量的数据增强手段,其特点:

1、Albumentations支持所有常见的计算机视觉任务,如分类、语义分割、实例分割、目标检测和姿态估计

2、该库提供了一个简单统一的API,用于处理所有数据类型:图像(rbg图像、灰度图像、多光谱图像)、分割掩码、边界框和关键点。

3、该库包含70多种不同的增强功能,可以从现有数据中生成新的训练样本。

4、Albumentations快。我们对每个新版本进行基准测试,以确保增强功能提供最大的速度。

5、它与流行的深度学习框架(如PyTorch和TensorFlow)一起工作。顺便说一下,Albumentations是PyTorch生态系统的一部分。

6、由专家写的。作者既有生产计算机视觉系统的工作经验,也有参与竞争性机器学习的经验。许多核心团队成员是Kaggle Masters和Grandmasters。

7、该库广泛应用于工业、深度学习研究、机器学习竞赛和开源项目。

github及其示例地址如下:

可以通过 pip 的方式直接安装,也可以通过 pip + github 的方式,或者conda

  • pip 方式:pip install albumentations
  • pip + github:pip install -U git+https://github.com/albu/albumentations
  • conda方式,此方式需要先安装 imgaug,然后在安装 albumentations
conda install -c conda-forge imgaug
conda install albumentations -c albumentations
安装问题

在安装部分有个小问题,我的本机已经安装完opencv-python,然后我们再去安装albumentations的时候,出现了一个问题,就是我们的opencv-python阻止albumentations的安装,报错如下:# Could not install packages due to anEnvironmentError: [WinError 5] 拒绝访问,是因为在安装albumentations的时候还要安装opencv-python-headless,这个库和opencv冲突。

解决方式

对于这个问题的解决方式是我们在新的虚拟环境中,先安装albumentations,安装albumentations的时候,我们会伴随安装上一个opencv-python-headless,这个完全可以替代opencv-python的功能,所以我们就不用安装opencv了。

Spatial-level transforms(空间层次转换)

空间级转换将同时改变输入图像和附加目标,如掩模、边界框和关键点。下表显示了每个转换支持哪些附加目标。

Spatial-level transforms(空间层次转换) 空间级转换将同时改变输入图像和附加目标,如掩模、边界框和关键点。下表显示了每个转换支持哪些附加目标。

支持的列表

Blur
CLAHE
ChannelDropout
ChannelShuffle
ColorJitter
Downscale
Emboss
Equalize
FDA
FancyPCA
FromFloat
GaussNoise
GaussianBlur
GlassBlur
HistogramMatching
HueSaturationValue
ISONoise
ImageCompression
InvertImg
MedianBlur
MotionBlur
MultiplicativeNoise
Normalize
Posterize
RGBShift
RandomBrightnessContrast
RandomFog
RandomGamma
RandomRain
RandomShadow
RandomSnow
RandomSunFlare
RandomToneCurve
Sharpen
Solarize
Superpixels
ToFloat
ToGray
ToSepia

how to use:类似totch中的 transform 模块

import albumentations as A
import cv2
 
import matplotlib.pyplot as plt
 
# Declare an augmentation pipeline
transform = A.Compose([
    A.RandomCrop(width=512, height=512),
    A.HorizontalFlip(p=0.8),
    A.RandomBrightnessContrast(p=0.5),
])
 
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
# Augment an image
transformed = transform(image=image)
transformed_image = transformed["image"]
plt.imshow(transformed_image)
plt.show()

详细使用案例:

1、VerticalFlip 围绕X轴垂直翻转输入

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.VerticalFlip(always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('Blur后的图像')
plt.imshow(transformed_image)
plt.show()

2、Blur模糊输入图像

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.Blur(blur_limit=15,always_apply=False, p=1)(image=image) 
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('Blur后的图像')
plt.imshow(transformed_image)
plt.show()

3、HorizontalFlip 围绕y轴水平翻转输入

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.HorizontalFlip(always_apply=False, p=1)(image=image) 
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('HorizontalFlip后的图像')
plt.imshow(transformed_image)
plt.show()

4、Flip水平,垂直或水平和垂直翻转输入

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.Flip(always_apply=False, p=1)(image=image) 
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('Flip后的图像')
plt.imshow(transformed_image)
plt.show()

5、Transpose, 通过交换行和列来转置输入

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.Transpose(always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('Transpose后的图像')
plt.imshow(transformed_image)
plt.show()

6、RandomCrop 随机裁剪

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RandomCrop(512, 512,always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('RandomCrop后的图像')
plt.imshow(transformed_image)
plt.show()

7、RandomGamma 随机灰度系数

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RandomGamma(gamma_limit=(20, 20), eps=None, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('RandomGamma后的图像')
plt.imshow(transformed_image)
plt.show()

8、RandomRotate90 将输入随机旋转90度,N次

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RandomRotate90(always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('RandomRotate90后的图像')
plt.imshow(transformed_image)
plt.show()

10、ShiftScaleRotate 随机平移,缩放和旋转输入

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
#解决中文显示问题
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, interpolation=1, border_mode=4, value=None, mask_value=None, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')   #第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title('ShiftScaleRotate后的图像')
plt.imshow(transformed_image)
plt.show()

11、CenterCrop 裁剪图像的中心部分


# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.CenterCrop(256, 256, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("CenterCrop后的图像")
plt.imshow(transformed_image)
plt.show()

12、GridDistortion网格失真

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.GridDistortion(num_steps=10, distort_limit=0.3,border_mode=4, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("GridDistortion后的图像")
plt.imshow(transformed_image)
plt.show()

13、ElasticTransform 弹性变换

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.ElasticTransform(alpha=5, sigma=50, alpha_affine=50, interpolation=1, border_mode=4,always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("ElasticTransform后的图像")
plt.imshow(transformed_image)
plt.show()

14、RandomGridShuffle把图像切成网格单元随机排列

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RandomGridShuffle(grid=(3, 3), always_apply=False, p=1) (image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("RandomGridShuffle后的图像")
plt.imshow(transformed_image)
plt.show()

15、HueSaturationValue随机更改图像的颜色,饱和度和值

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("HueSaturationValue后的图像")
plt.imshow(transformed_image)
plt.show()

16、PadIfNeeded 填充图像

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.PadIfNeeded(min_height=2048, min_width=2048, border_mode=4, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("PadIfNeeded后的图像")
plt.imshow(transformed_image)
plt.show()

17、RGBShift,对图像RGB的每个通道随机移动值

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RGBShift(r_shift_limit=10, g_shift_limit=20, b_shift_limit=20, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("RGBShift后的图像")
plt.imshow(transformed_image)
plt.show()

18、GaussianBlur 使用随机核大小的高斯滤波器对图像进行模糊处理

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.GaussianBlur(blur_limit=11, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("GaussianBlur后的图像")
plt.imshow(transformed_image)
plt.show()

CLAHE自适应直方图均衡

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=0.5)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("CLAHE后的图像")
plt.imshow(transformed_image)
plt.show()

ChannelShuffle随机重新排列输入RGB图像的通道

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.ChannelShuffle(always_apply=False, p=0.5)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("ChannelShuffle后的图像")
plt.imshow(transformed_image)
plt.show()

InvertImg反色

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.InvertImg(always_apply=False, p=0.5)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("InvertImg后的图像")
plt.imshow(transformed_image)
plt.show()

Cutout 随机擦除

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.Cutout(num_holes=20, max_h_size=20, max_w_size=20, fill_value=0, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("Cutout后的图像")
plt.imshow(transformed_image)
plt.show()

RandomFog随机雾化

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.RandomFog(fog_coef_lower=0.3, fog_coef_upper=1, alpha_coef=0.08, always_apply=False, p=1)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("RandomFog后的图像")
plt.imshow(transformed_image)
plt.show()

GridDropout网格擦除

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
 
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Read an image with OpenCV and convert it to the RGB colorspace
image = cv2.imread("aa.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Augment an image
transformed = A.GridDropout(ratio=0.5, unit_size_min=None, unit_size_max=None, holes_number_x=None, holes_number_y=None,
                            shift_x=0, shift_y=0, always_apply=False, p=0.5)(image=image)
transformed_image = transformed["image"]
plt.subplot(1, 2, 1)
plt.title('原图')  # 第一幅图片标题
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.title("GridDropout后的图像")
plt.imshow(transformed_image)
plt.show()

组合变换(Compose)

变换不仅可以单独使用,还可以将这些组合起来,这就需要用到 Compose 类,该类继承自 BaseCompose。Compose 类含有以下参数:

  • transforms:转换类的数组,list类型
  • bbox_params:用于 bounding boxes 转换的参数,BboxPoarams 类型
  • keypoint_params:用于 keypoints 转换的参数, KeypointParams 类型
  • additional_targets:key新target 名字,value 为旧 target 名字的 dict,如 {‘image2’: ‘image’},dict 类型
  • p:使用这些变换的概率,默认值为 1.0
image3 = Compose([
        # 对比度受限直方图均衡
            #(Contrast Limited Adaptive Histogram Equalization)
        CLAHE(),
        # 随机旋转 90°
        RandomRotate90(),
        # 转置
        Transpose(),
        # 随机仿射变换
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.50, rotate_limit=45, p=.75),
        # 模糊
        Blur(blur_limit=3),
        # 光学畸变
        OpticalDistortion(),
        # 网格畸变
        GridDistortion(),
        # 随机改变图片的 HUE、饱和度和值
        HueSaturationValue()
    ], p=1.0)(image=image)['image']

随机选择(OneOf)

它同Compose一样,都是做组合的,都有概率。区别就在于:Compose组合下的变换是要挨着顺序做的,而OneOf组合里面的变换是系统自动选择其中一个来做,而这里的概率参数p是指选定后的变换被做的概率。例:

image4 = Compose([
        RandomRotate90(),
        # 翻转
        Flip(),
        Transpose(),
        OneOf([
            # 高斯噪点
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.2),
        OneOf([
            # 模糊相关操作
            MotionBlur(p=.2),
            MedianBlur(blur_limit=3, p=0.1),
            Blur(blur_limit=3, p=0.1),
        ], p=0.2),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
        OneOf([
            # 畸变相关操作
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.2),
        OneOf([
            # 锐化、浮雕等操作
            CLAHE(clip_limit=2),
            IAASharpen(),
            IAAEmboss(),
            RandomBrightnessContrast(),            
        ], p=0.3),
        HueSaturationValue(p=0.3),
    ], p=1.0)(image=image)['image']

在程序中的使用

def get_transform(phase: str):
    if phase == 'train':
        return Compose([
            A.RandomResizedCrop(height=CFG.img_size, width=CFG.img_size),
            A.Flip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(p=0.5),
            A.HueSaturationValue(p=0.5),
            A.OneOf([
                A.RandomBrightnessContrast(p=0.5),
                A.RandomGamma(p=0.5),
            ], p=0.5),
            A.OneOf([
                A.Blur(p=0.1),
                A.GaussianBlur(p=0.1),
                A.MotionBlur(p=0.1),
            ], p=0.1),
            A.OneOf([
                A.GaussNoise(p=0.1),
                A.ISONoise(p=0.1),
                A.GridDropout(ratio=0.5, p=0.2),
                A.CoarseDropout(max_holes=16, min_holes=8, max_height=16, max_width=16, min_height=8, min_width=8, p=0.2)
            ], p=0.2),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ])
    else:
        return Compose([
            A.Resize(height=CFG.img_size, width=CFG.img_size),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ])

分类问题中的使用

在 albumentations 中可以用于分类问题中的操作包括:

HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue, IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, RandomBrightnessContrast, IAAPiecewiseAffine, IAASharpen, IAAEmboss, Flip, OneOf, Compose

分割问题中的使用

在此示例中,需要使用到如下的类:

PadIfNeeded, HorizontalFlip, VerticalFlip, CenterCrop, Crop, Compose, Transpose, RandomRotate90, ElasticTransform, GridDistortion, OpticalDistortion, RandomSizedCrop, OneOf, CLAHE, RandomBrightnessContrast, RandomGamma

填充(Padding)

在 Unet 这样的网络架构中,输入图片的尺寸需要尺寸需要能被 $2^N$ 整除,其中 $N$ 是池化层(maxpooling)的层数。在最简单的 Unet 结构中 $N$ 的值为 5,那么我们就需要将输入的图片填充到能被 $2^5=32$ 除尽的数字,应该上面图片的大小为 101,因此最接近的大小为 128。要进行此操作就需要用到 PadIfNeeded 类,其含有如下参数:

  • min_height:最终图片的最小高度,int 类型
  • min_width:最终图片的最小宽度,int 类型
  • border_mode:OpenCV 边界模式,默认值为 cv2.BORDER_REFLECT_101
  • value:如果border_mode 值为 cv2.BORDER_CONSTANT 时的填充值,int、float或者 int、float数组类型
  • mask_value:如果border_mode 值为 cv2.BORDER_CONSTANT 时 mask 的填充值,int、float或者 int、float数组类型
  • p:进行此转换的概率,默认值为 1.0

默认条件下 PadIfNeeded 会对图片和mask的四条边都进行填充,填充的类型包括零填充(zero)、常量填充(constant)和反射填充(reflection),默认为反射填充。使用方法如下:

image1 = PadIfNeeded(p=1, min_height=128, min_width=128)(image=image, mask=mask)
image11_padded = image1['image']
mask11_padded = image1['mask']
# (128, 128, 3) (128, 128)
print(image11_padded.shape, mask11_padded.shape)

裁剪与中心裁剪(Crop & CenterCrop)

上面我们使用了 PadIfNeeded 对图片进行了填充,想要恢复原始的大小这时候就可以使用相关的裁剪方法:CenterCropCrop 等类。

先来看 CenterCrop 的使用,它主要从输入的图片中间进行裁剪,主要含有以下参数:

  • height:裁剪的高度,int 类型
  • width:裁剪的宽度,int 类型
  • p:使用此转换方法的概率,默认值为 1.0

原始的图片和mask大小为 101,因此此处设置需要裁剪的宽高(original_height/original_width)为 101,使用方法如下:

image2 = CenterCrop(p=1.0, height=original_height, 
                    width=original_width)(image=image11_padded, mask=mask11_padded)

image22_center_cropped = image2['image']
mask22_center_cropped = image2['mask']
# (101, 101, 3) (101, 101)
print(image22_center_cropped.shape, mask22_center_cropped.shape)

非破坏性转换

从上面的转换操作中可以看到操作破坏了图像的空间信息,对于想卫星、航空或者医学图片我们并不希望破坏它原有的空间结构,如以下的八种操作就不会破坏原有图片的空间结构。

通过 HorizontalFlipVerticalFlipTransposeRandomRotate90 四种操作的组合就可以得到上面的八种操作。这些操作可以参考上面《分类问题中的使用》章节。

非刚体转换

在医学影像问题中非刚体装换可以帮助增强数据。albumentations 中主要提供了以下几种非刚体变换类:ElasticTransformGridDistortion 和 OpticalDistortion。三个类的主要参数如下:

ElasticTransform 类参数:

  • alphasigma:高斯过滤参数,float类型
  • alpha_affine:范围为 (-alpha_affine, alpha_affine),float 类型
  • interpolationborder_modevaluemask_value:与其他类含义一样
  • approximate:是否应平滑具有固定大小核的替换映射(displacement map),若启用此选项,在大图上会有两倍的速度提升,boolean类型。
  • p:使用此转换的概率,默认值为 0.5

GridDistortion 类参数:

  • num_steps:在每一条边上网格单元的数量,默认值为 5,int 类型
  • distort_limit:如果是单值,那么会被转成 (-distort_limit, distort_limit),默认值为 (-0.03, 0.03),float或float数组类型
  • interpolationborder_modevaluemask_value:与其他类含义一样
  • p:使用此转换的概率,默认值为 0.5

OpticalDistortion 类参数:

  • distort_limit:如果是单值,那么会被转成 (-distort_limit, distort_limit),默认值为 (-0.05, 0.05),float或float数组类型
  • shift_limit:如果是单值,那么会被转成 (-shift_limit, shift_limit),默认值为 (-0.05, 0.05),float或float数组类型
  • interpolationborder_modevaluemask_value:与其他类含义一样
  • p:使用此转换的概率,默认值为 0.5

使用方式如下:

# 弹性装换
image41 = ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, 
                          alpha_affine=120 * 0.03)(image=image, mask=mask)
image_elastic = image41['image']
mask_elastic = image41['mask']

# 网格畸变
image42 = GridDistortion(p=1, num_steps=10)(image=image, mask=mask)

image_grid = image42['image']
mask_grid = image42['mask']

# 光学畸变
image43 = OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)(image=image, mask=mask)

image_optical = image43['image']
mask_optical = image43['mask']

效果如下:

组合多种转换

我们可以将上面的填充、裁剪、非刚体转换、非破坏性转换组合起来:

image5 = Compose([
   # 非刚体转换
    OneOf([RandomSizedCrop(min_max_height=(50, 101), 
                           height=original_height, width=original_width, p=0.5),
          PadIfNeeded(min_height=original_height, 
                      min_width=original_width, p=0.5)], p=1),
    # 非破坏性转换
    VerticalFlip(p=0.5),              
    RandomRotate90(p=0.5),
    # 非刚体转换
    OneOf([
        ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
        GridDistortion(p=0.5),
        OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)                  
        ], p=0.8),
    # 非空间性转换
    CLAHE(p=0.8),
    RandomBrightnessContrast(p=0.8),    
    RandomGamma(p=0.8)])(image=image, mask=mask)

image_heavy = image5['image']
mask_heavy = image5['mask']

运行的效果如下:

python 深浅拷贝

拷贝是Python学习过程中很容易被忽略,但是在项目开发过程中起着重要作用的一个概念。

有很多开发者由于忽视这一点,甚至导致项目中出现很严重的BUG。

我之前就因为这样的一个小问题,一不小心掉坑里了。反复定位才发现竟然是由这个容易被忽视的问题引起的….

在这篇文章中,我们将看看如何在Python中深度和浅度拷贝对象,深入探讨Python 如何处理对象引用和内存中的对象。

浅拷贝

当我们在 Python 中使用赋值语句 (=) 来创建复合对象的副本时,例如,列表或类实例或基本上任何包含其他对象的对象,Python 并没有克隆对象本身。

相反,它只是将引用绑定到目标对象上。

想象一下,我们有一个列表,里面有以下元素。

original_list =[[1,2,3], [4,5,6], ["X", "Y", "Z"]]

如果我们尝试使用如下的赋值语句来复制我们的原始列表。

shallow_copy_list = original_list
print(shallow_copy_list)

它可能看起来像我们克隆了我们的对象,或许很多同学会认为生成了两个对象,

[[1,2,3], [4,5,6], ['X', 'Y', 'Z']]

但是,我们真的有两个对象吗?

不,并没有。我们有两个引用变量,指向内存中的同一个对象。通过打印这两个对象在内存中的ID,可以很容易地验证这一点。

id(original_list) # 4517445712
id(shallow_copy_list) # 4517445712

一个更具体的证明可以通过尝试改变 “两个列表”中的一个值来观察–而实际上,我们改变的是同一个列表,两个指针指向内存中的同一个对象。

让我们来改变original_list所指向的对象的最后一个元素。

# Last element of last element
original_list[-1][-1] = "ZZZ"
print(original_list)

输出结果是:

[[1, 2, 3], [4, 5, 6], ['X', 'Y', 'ZZZ']]

两个引用变量都指向同一个对象,打印shallow_copy_list将返回相同的结果。

print(shallow_copy_list) # [[1, 2, 3], [4, 5, 6], ['X', 'Y', 'ZZZ']]

浅层复制是指复制一个对象的引用并将其存储在一个新的变量中的过程。original_list和shallow_copy_list只是指向内存(RAM)中相同地址的引用,这些引用存储了[[1, 2, 3], [4, 5, 6], ['X', 'Y', 'ZZZ']的值。

我们在复制过程中,并没有生成一个新的对象,试想一下,如果不理解这一点,很多同学会误认为它生成了一个完全独立的新对象,殊不知,在对这个新变量shallow_copy_list进行操作时,原来的变量original_list也会跟随改变。

除了赋值语句之外,还可以通过Python标准库的拷贝模块实现浅拷贝

要使用拷贝模块,我们必须首先导入它。

import copy
second_shallow_copy_list = copy.copy(original_list)

把它们都打印出来,看看它们是否引用了相同的值。

print(original_list)
print(second_shallow_copy_list)

不出所料,确实如此,

[[1, 2, 3], [4, 5, 6], ['X', 'Y', 'ZZZ']]
[[1, 2, 3], [4, 5, 6], ['X', 'Y', 'ZZZ']]

通常,你想复制一个复合对象,例如在一个方法的开始,然后修改克隆的对象,但保持原始对象的原样,以便以后再使用它。

为了达到这个目的,我们需要对该对象进行深度复制。现在让我们来学习一下什么是深度拷贝以及如何深度拷贝一个复合对象。

深拷贝

深度复制一个对象意味着真正地将该对象和它的值克隆到内存中的一个新的副本(实例)中,并具有这些相同的值。

通过深度拷贝,我们实际上可以创建一个独立于原始数据的新对象,但包含相同的值,而不是为相同的值创建新的引用。

在一个典型的深度拷贝过程中,首先,一个新的对象引用被创建,然后所有的子对象被递归地加入到父对象中。

这样一来,与浅层拷贝不同,对原始对象的任何修改都不会反映在拷贝对象中(反之亦然)。

下面是一个典型的深度拷贝的简单图示。

要在 Python 中深度拷贝一个对象,我们使用 copy 模块的 deepcopy()方法。

让我们导入 copy 模块并创建一个列表的深度拷贝。

import copy
 
original_list = [[1,2,3], [4,5,6], ["X", "Y", "Z"]]
deepcopy_list = copy.deepcopy(original_list)

现在让我们打印我们的列表,以确保输出是相同的,以及他们的ID是唯一的。

print(id(original_list), original_list)
print(id(deepcopy_list), deepcopy_list)

输出结果证实,我们已经为自己创建了一个真正的副本。

4517599280, [[1, 2, 3], [4, 5, 6], ['X', 'Y', 'Z']]
4517599424, [[1, 2, 3], [4, 5, 6], ['X', 'Y', 'Z']]

现在让我们试着修改我们的原始列表,把最后一个列表的最后一个元素改为 “O”,然后打印出来看看结果。

original_list[-1][-1] = "O"
print(original_list)

我们得到了预期的结果。

[[1, 2, 3], [4, 5, 6], ['X', 'Y', 'O']]

现在,如果我们继续前进并尝试打印我们的副本列表,之前的修改并没有影响新的变量。

print(deepcopy_list) # [[1, 2, 3], [4, 5, 6], ['X', 'Y', 'Z']]

记住,copy()deepcopy()方法适用于其他复合对象。这意味着,你也可以用它们来创建类实例的副本。

pytorch如何加载不同尺寸的图片数据

如何使用dataloader加载相同维度但是不同尺寸的数据集(图片),不使用resize,crop等改变模型输入的shape。

知乎:https://www.zhihu.com/question/395888465

如果加载的数据的维度尺寸不相同的话,在迭代器中会爆出如下的错误

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0.

1、pytorch的dataloader默认的collate_fn会使用torch.stack合并多张图片成为batch

要么另外写一个collate_fn

要么在dataset类中对图片做padding,使得图片的size一样,可以直接stack

2、关于collate_fn:

 https://pytorch.org/docs/stable/data.html#working-with-collate-fn

The use of collate_fn is slightly different when automatic batching is enabled or disabled.

  • When automatic batching is disabledcollate_fn is called with each individual data sample, and the output is yielded from the data loader iterator. In this case, the default collate_fn simply converts NumPy arrays in PyTorch tensors.
  • When automatic batching is enabledcollate_fn is called with a list of data samples at each time. It is expected to collate the input samples into a batch for yielding from the data loader iterator. The rest of this section describes behavior of the default collate_fn in this case.

可以看到,你可以考虑关闭自动打包,这样collate_fn处理的就是独立的样本。也可以打开自动打包,这样这个函数就会被输入一个batch列表的数据。注意,这个列表的数据可以不同大小哦,知识这样你就没办法将其stack成一个完整的batch。所以,实际上你的报错,应该是这个位置出的问题。

所以可以考虑以下几种策略:

  • 单个样本输入,这样同一个batch组合的时候就不需要担心了
  • 对输入样本padding成最大的形状,组合成batch,之后送入网络的时候,你可以把数据拆分开,按你想要的将其去掉padding或者其他操作
  • 正常读取,之后再自定义的collate_fn中将数据拆开返回,这样可以返回相同结构的数据

对于最后一点,给个小demo:

class OurDataset(Dataset):
    def __init__(self, *tensors):
        self.tensors = tensors
    def __getitem__(self, index):
        return self.tensors[index]
    def __len__(self):
        return len(self.tensors)

def collate_wrapper(batch):
#函数就会输入一个batch的列表的数据(注意是batch是一个列表,所以里面的数据可以不同大小)
    a, b = batch
    return a, b

a = torch.randn(3, 2, 3)
b = torch.randn(3, 3, 4)
dataset = OurDataset(a, b)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper)

for sample in loader:
    print([x.size() for x in sample])

# Out: [torch.Size([1, 3, 2, 3]), torch.Size([1, 3, 3, 4])]

torch.roll 函数

The Question about the mask of window attention:

https://github.com/microsoft/Swin-Transformer/issues/38

torch.roll(inputshiftsdims=None) → Tensor

Roll the tensor input along the given dimension(s). Elements that are shifted beyond the last position are re-introduced at the first position. If dims is None, the tensor will be flattened before rolling and then restored to the original shape.Parameters

  • input (Tensor) – the input tensor.
  • shifts (int or tuple of python:ints) – The number of places by which the elements of the tensor are shifted. If shifts is a tuple, dims must be a tuple of the same size, and each dimension will be rolled by the corresponding value
  • dims (int or tuple of python:ints) – Axis along which to roll

沿给定维数滚动张量,移动到最后一个位置以外的元素将在第一个位置重新引入。如果没有指定尺寸,张量将在轧制前被压平,然后恢复到原始形状。

简单理解:shifts的值为正数相当于向下挤牙膏,挤出的牙膏又从顶部塞回牙膏里面;shifts的值为负数相当于向上挤牙膏,挤出的牙膏又从底部塞回牙膏里面

  • input (Tensor) —— 输入张量。
  • shifts (python:int 或 tuple of python:int) —— 张量元素移位的位数。如果该参数是一个元组(例如shifts=(x,y)),dims必须是一个相同大小的元组(例如dims=(a,b)),相当于在第a维度移x位,在b维度移y位
  • dims (int 或 tuple of python:int) 确定的维度。

Example:

>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
>>> x
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
>>> torch.roll(x, 1)
tensor([[8, 1],
        [2, 3],
        [4, 5],
        [6, 7]])

'''第0维度向下移1位,多出的[7,8]补充到顶部'''
>>> torch.roll(x, 1, 0)
tensor([[7, 8],
        [1, 2],
        [3, 4],
        [5, 6]])

'''第0维度向上移1位,多出的[1,2]补充到底部'''
>>> torch.roll(x, -1, 0)
tensor([[3, 4],
        [5, 6],
        [7, 8],
        [1, 2]])

'''tuple元祖,维度一一对应:
第0维度向下移2位,多出的[5,6][7,8]补充到顶部,
第1维向右移1位,多出的[6,8,2,4]补充到最左边'''
>>> torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[6, 5],
        [8, 7],
        [2, 1],
        [4, 3]])

空洞卷积

Multi-Scale Context Aggregation by Dilated Convolutions

一个简单的例子,[动态图来源:vdumoulin/conv_arithmetic]:

动图
Standard Convolution with a 3 x 3 kernel (and padding)
动图封面
Dilated Convolution with a 3 x 3 kernel and dilation rate 2

对于 dilated convolution, 我们已经可以发现他的优点,即内部数据结构的保留和避免使用 down-sampling 这样的特性。但是完全基于 dilated convolution 的结构如何设计则是一个新的问题。

潜在问题 1:The Gridding Effect

假设我们仅仅多次叠加 dilation rate 2 的 3 x 3 kernel 的话,则会出现这个问题:

我们发现我们的 kernel 并不连续,也就是并不是所有的 pixel 都用来计算了,因此这里将信息看做 checker-board 的方式会损失信息的连续性。这对 pixel-level dense prediction 的任务来说是致命的。

潜在问题 2:Long-ranged information might be not relevant.

我们从 dilated convolution 的设计背景来看就能推测出这样的设计是用来获取 long-ranged information。然而光采用大 dilation rate 的信息或许只对一些大物体分割有效果,而对小物体来说可能则有弊无利了。如何同时处理不同大小的物体的关系,则是设计好 dilated convolution 网络的关键。

通向标准化设计:Hybrid Dilated Convolution (HDC)

对于上个 section 里提到的几个问题,图森组的文章对其提出了较好的解决的方法。他们设计了一个称之为 HDC 的设计结构。

第一个特性是,叠加卷积的 dilation rate 不能有大于1的公约数。比如 [2, 4, 6] 则不是一个好的三层卷积,依然会出现 gridding effect。

第二个特性是,我们将 dilation rate 设计成 锯齿状结构,例如 [1, 2, 5, 1, 2, 5] 循环结构。

第三个特性是,我们需要满足一下这个式子: Mi=max[Mi+1−2ri,Mi+1−2(Mi+1−ri),ri]

其中 ri 是 i 层的 dilation rate 而 Mi 是指在 i 层的最大dilation rate,那么假设总共有n层的话,默认 Mn=rn 。假设我们应用于 kernel 为 k x k 的话,我们的目标则是 M2≤k ,这样我们至少可以用 dilation rate 1 即 standard convolution 的方式来覆盖掉所有洞。

一个简单的例子: dilation rate [1, 2, 5] with 3 x 3 kernel (可行的方案)

而这样的锯齿状本身的性质就比较好的来同时满足小物体大物体的分割要求(小 dilation rate 来关心近距离信息,大 dilation rate 来关心远距离信息)。

这样我们的卷积依然是连续的也就依然能满足VGG组观察的结论,大卷积是由小卷积的 regularisation 的 叠加。

代码:(绘制空洞卷积)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


def dilated_conv_one_pixel(center: (int, int),feature_map: np.ndarray,k: int = 3,r: int = 1,v: int = 1):
    """
    膨胀卷积核中心在指定坐标center处时,统计哪些像素被利用到,
    并在利用到的像素位置处加上增量v
    Args:
        center: 膨胀卷积核中心的坐标
        feature_map: 记录每个像素使用次数的特征图
        k: 膨胀卷积核的kernel大小
        r: 膨胀卷积的dilation rate
        v: 使用次数增量
    """
    assert divmod(3, 2)[1] == 1

    # left-top: (x, y)
    left_top = (center[0] - ((k - 1) // 2) * r, center[1] - ((k - 1) // 2) * r)
    for i in range(k):
        for j in range(k):
            feature_map[left_top[1] + i * r][left_top[0] + j * r] += v


def dilated_conv_all_map(dilated_map: np.ndarray,
                         k: int = 3,
                         r: int = 1):
    """
    根据输出特征矩阵中哪些像素被使用以及使用次数,
    配合膨胀卷积k和r计算输入特征矩阵哪些像素被使用以及使用次数
    Args:
        dilated_map: 记录输出特征矩阵中每个像素被使用次数的特征图
        k: 膨胀卷积核的kernel大小
        r: 膨胀卷积的dilation rate
    """
    new_map = np.zeros_like(dilated_map)
    for i in range(dilated_map.shape[0]):
        for j in range(dilated_map.shape[1]):
            if dilated_map[i][j] > 0:
                dilated_conv_one_pixel((j, i), new_map, k=k, r=r, v=dilated_map[i][j])

    return new_map


def plot_map(matrix: np.ndarray):
    plt.figure()

    c_list = ['white', 'blue', 'red']
    new_cmp = LinearSegmentedColormap.from_list('chaos', c_list)
    plt.imshow(matrix, cmap=new_cmp)

    ax = plt.gca()
    ax.set_xticks(np.arange(-0.5, matrix.shape[1], 1), minor=True)
    ax.set_yticks(np.arange(-0.5, matrix.shape[0], 1), minor=True)

    # 显示color bar
    plt.colorbar()

    # 在图中标注数量
    thresh = 5
    for x in range(matrix.shape[1]):
        for y in range(matrix.shape[0]):
            # 注意这里的matrix[y, x]不是matrix[x, y]
            info = int(matrix[y, x])
            ax.text(x, y, info,
                    verticalalignment='center',
                    horizontalalignment='center',
                    color="white" if info > thresh else "black")
    ax.grid(which='minor', color='black', linestyle='-', linewidth=1.5)
    plt.show()
    plt.close()


def main():
    # bottom to top
    dilated_rates = [1, 2, 3]
    # init feature map
    size = 31
    m = np.zeros(shape=(size, size), dtype=np.int32)
    center = size // 2
    m[center][center] = 1
    # print(m)
    # plot_map(m)

    for index, dilated_r in enumerate(dilated_rates[::-1]):
        new_map = dilated_conv_all_map(m, r=dilated_r)
        m = new_map
    print(m)
    plot_map(m)

绘制结果:

torch grid_sample() 函数

grid_sample底层是应用双线性插值,把输入的tensor转换为指定大小。那它和interpolate有啥区别呢?
interpolate是规则采样(uniform),但是grid_sample的转换方式,内部采点的方式并不是规则的,是一种更为灵活的方式。可以认为采样点根据 grid 矩阵来决定。

Pytorch中grid_sample函数的接口声明如下:

torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)

在官方文档里面关于该函数的作用是这样描述的:

Given an input and a flow-field grid, computes the output using input values and pixel locations from grid.

简单来说就是,提供一个input的Tensor以及一个对应的flow-field网格(比如光流,体素流等),然后根据网格(grid)中每个位置提供的坐标信息(这里指input中pixel的坐标),将input中对应位置的像素值填充到grid指定的位置,得到最终的输出。

关于input、grid以及output的尺寸如下所示:(input也可以是5D的Tensor,这里我们只考虑4D的情况) 注意output的尺寸可以大于input,所以 grid_sample 可以用来上采样

input:(N,C,Hin,Win)
grid:(N,Hout,Wout​,2)
output:(N,C,Hout,Wout​)

这里的input和output就是输入的图片,或者是网络中的feature map。关键的处理过程在于grid,grid的最后一维的大小为2,即表示input中pixel的位置信息 (x,y) ,这里一般会将x和y的取值范围归一化到 [−1,1] 之间, (−1,−1) 表示input左上角的像素的坐标,(1,1) 表示input右下角的像素的坐标,对于超出这个范围的坐标(x,y),函数将会根据参数padding_mode的设定进行不同的处理。

  • padding_mode=’zeros’:对于越界的位置在网格中采用pixel value=0进行填充。
  • padding_mode=’border’:对于越界的位置在网格中采用边界的pixel value进行填充。
  • padding_mode=’reflection’:对于越界的位置在网格中采用关于边界的对称值进行填充。

对于mode=’bilinear’参数,则定义了在input中指定位置的pixel value中进行插值的方法,为什么需要插值呢?因为前面我们说了,grid中表示的位置信息x和y的取值范围在 [−1,1] 之间,这就意味着我们要根据一个浮点型的坐标值在input中对pixel value进行采样,mode有'bilinear' | 'nearest' | 'bicubic'(双三次插值)三种模式。 nearest就是直接采用与 (x,y) 距离最近处的像素值来填充grid,而bilinear则是采用双线性插值的方法来进行填充,mode=’bicubic’仅支持四维输入,总之其与nearest的区别就是nearest只考虑最近点的pixel value,而bilinear则采用(x,y)周围的四个pixel value进行加权平均值来填充grid。

双线性插值:双线性插值是用原图像中4(2*2)个点计算新图像中1个点

双三次插值(Bicubic interpolation):双三次插值是用原图像中16(4*4)个点计算新图像中1个点,效果比较好,但是计算代价过大

上面讲到双线性插值会对 (x,y) 周围的四个pixel value进行加权平均,那么每个位置的权重是多少呢?可以简单参考下图中双线性插值的例子:

其双线性插值的结果为:

采用下图我们可以对双线性插值有个更为直观的认识:

从上图中可以看到双线性插值就是首先在平面 zoy 内,对 f(x0,y0) 和 f(x0,y1) 进行插值得到 z1 ,对 f(x1,y0) 和 f(x1,y1) 进行插值得到 z2 ,随后在平面 zox 内进行插值得到最终的 z 点的值就是最终所求的结果,这里的平面内插值其实就是采用我们高中学的,直线的两点式求出直线表达是,再带入自变量(x或y)的坐标得到插值的结果。联立两次直线的两点式就能得到双线性插值的结果,说到这里“双线性”也顾名思义了。

下面给出正式的推导:

已知四点的坐标如下所示:

Q11=(x1,y1,f(x1,y1)), Q21=(x2,y1,f(x2,y2)), Q12=(x1,y2,f(x3,y3)), Q22=(x2,y2,f(x4,y4))

其中有z=f(x,y):

先在x方向上进行插值有:

以上式子便是最终双线性插值的最终表达式,由于4个点的权重部分中分母是相同的可以忽略不计,现在再回去看上面的例子是不是就一目了然了。

例子:

import torch
from torch.nn import functional as F

inp = torch.ones(1, 1, 4, 4)
print(inp)
# 目的是得到一个 长宽为20的tensor
out_h = 20
out_w = 20
 # grid的生成方式等价于用mesh_grid
new_h = torch.linspace(-1, 1, out_h).view(-1, 1).repeat(1, out_w)
new_w = torch.linspace(-1, 1, out_w).repeat(out_h, 1)
grid = torch.cat((new_h.unsqueeze(2), new_w.unsqueeze(2)), dim=2)
grid = grid.unsqueeze(0) #返回一个新的张量,对输入的既定位置插入维度 1
print(grid.shape)
outp = F.grid_sample(inp, grid=grid, mode='bilinear')
print(outp.shape)  #torch.Size([1, 1, 20, 20])

在上面的例子中,我们将一个大小为4×4的tensor 转换为了一个20×20的。grid的大小指定了输出大小,每个grid的位置是一个(x,y)坐标,其值来自于:输入input的(x,y)中 的四邻域插值得到的。

在这里插入图片描述
图片来自于SFnet(eccv2020)。flow field是grid, low_resolution是input, high resolution是output。

python 类型注释 # type

Type Comments[类型注解]

注释是在Python 3中引入的,并且它们没有被反向移植到Python 2.这意味着如果您正在编写需要支持旧版Python的代码,则无法使用注释。

要向函数添加类型注释,您可以执行以下操作:

import math 
def circumference(radius):    
# type: (float) -> float    
   return 2 * math.pi * radius

类型注释只是注释,所以它们可以用在任何版本的Python中。

类型注释由类型检查器直接处理,所以不存在__annotations__字典对象中:

>>> circumference.__annotations__{}

类型注释必须以type: 字面量开头,并与函数定义位于同一行或下一行。如果您想用几个参数来注释一个函数,您可以用逗号分隔每个类型:

def headline(text, width=80, fill_char="-"):  
  # type: (str, int, str) -> str    
   return f" {text.title()} ".center(width, fill_char) 

print(headline("type comments work", width=40))

您还可以使用自己的注释在单独的行上编写每个参数:

# headlines.py
 
  def headline(
      text,           # type: str
      width=80,       # type: int
      fill_char="-",  # type: str
  ):                  # type: (...) -> str
      return f" {text.title()} ".center(width, fill_char)
 
 print(headline("type comments work", width=40))

通过Python和Mypy运行示例:

$  python headlines.py
---------- Type Comments Work ---------- 
$ mypy headline.py
$

如果传入一个字符串width=”full”,再次运行mypy会出现一下错误。

$ mypy headline.py
headline.py:10: error: Argument "width" to "headline" has incompatible
                       type "str"; expected "int"

您还可以向变量添加类型注释。这与您向参数添加类型注释的方式类似:

pi = 3.142  # type: float

上面的例子可以检测出pi是float类型。