torch grid_sample() 函数

grid_sample底层是应用双线性插值,把输入的tensor转换为指定大小。那它和interpolate有啥区别呢?
interpolate是规则采样(uniform),但是grid_sample的转换方式,内部采点的方式并不是规则的,是一种更为灵活的方式。可以认为采样点根据 grid 矩阵来决定。

Pytorch中grid_sample函数的接口声明如下:

torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)

在官方文档里面关于该函数的作用是这样描述的:

Given an input and a flow-field grid, computes the output using input values and pixel locations from grid.

简单来说就是,提供一个input的Tensor以及一个对应的flow-field网格(比如光流,体素流等),然后根据网格(grid)中每个位置提供的坐标信息(这里指input中pixel的坐标),将input中对应位置的像素值填充到grid指定的位置,得到最终的输出。

关于input、grid以及output的尺寸如下所示:(input也可以是5D的Tensor,这里我们只考虑4D的情况) 注意output的尺寸可以大于input,所以 grid_sample 可以用来上采样

input:(N,C,Hin,Win)
grid:(N,Hout,Wout​,2)
output:(N,C,Hout,Wout​)

这里的input和output就是输入的图片,或者是网络中的feature map。关键的处理过程在于grid,grid的最后一维的大小为2,即表示input中pixel的位置信息 (x,y) ,这里一般会将x和y的取值范围归一化到 [−1,1] 之间, (−1,−1) 表示input左上角的像素的坐标,(1,1) 表示input右下角的像素的坐标,对于超出这个范围的坐标(x,y),函数将会根据参数padding_mode的设定进行不同的处理。

  • padding_mode=’zeros’:对于越界的位置在网格中采用pixel value=0进行填充。
  • padding_mode=’border’:对于越界的位置在网格中采用边界的pixel value进行填充。
  • padding_mode=’reflection’:对于越界的位置在网格中采用关于边界的对称值进行填充。

对于mode=’bilinear’参数,则定义了在input中指定位置的pixel value中进行插值的方法,为什么需要插值呢?因为前面我们说了,grid中表示的位置信息x和y的取值范围在 [−1,1] 之间,这就意味着我们要根据一个浮点型的坐标值在input中对pixel value进行采样,mode有'bilinear' | 'nearest' | 'bicubic'(双三次插值)三种模式。 nearest就是直接采用与 (x,y) 距离最近处的像素值来填充grid,而bilinear则是采用双线性插值的方法来进行填充,mode=’bicubic’仅支持四维输入,总之其与nearest的区别就是nearest只考虑最近点的pixel value,而bilinear则采用(x,y)周围的四个pixel value进行加权平均值来填充grid。

双线性插值:双线性插值是用原图像中4(2*2)个点计算新图像中1个点

双三次插值(Bicubic interpolation):双三次插值是用原图像中16(4*4)个点计算新图像中1个点,效果比较好,但是计算代价过大

上面讲到双线性插值会对 (x,y) 周围的四个pixel value进行加权平均,那么每个位置的权重是多少呢?可以简单参考下图中双线性插值的例子:

其双线性插值的结果为:

采用下图我们可以对双线性插值有个更为直观的认识:

从上图中可以看到双线性插值就是首先在平面 zoy 内,对 f(x0,y0) 和 f(x0,y1) 进行插值得到 z1 ,对 f(x1,y0) 和 f(x1,y1) 进行插值得到 z2 ,随后在平面 zox 内进行插值得到最终的 z 点的值就是最终所求的结果,这里的平面内插值其实就是采用我们高中学的,直线的两点式求出直线表达是,再带入自变量(x或y)的坐标得到插值的结果。联立两次直线的两点式就能得到双线性插值的结果,说到这里“双线性”也顾名思义了。

下面给出正式的推导:

已知四点的坐标如下所示:

Q11=(x1,y1,f(x1,y1)), Q21=(x2,y1,f(x2,y2)), Q12=(x1,y2,f(x3,y3)), Q22=(x2,y2,f(x4,y4))

其中有z=f(x,y):

先在x方向上进行插值有:

以上式子便是最终双线性插值的最终表达式,由于4个点的权重部分中分母是相同的可以忽略不计,现在再回去看上面的例子是不是就一目了然了。

例子:

import torch
from torch.nn import functional as F

inp = torch.ones(1, 1, 4, 4)
print(inp)
# 目的是得到一个 长宽为20的tensor
out_h = 20
out_w = 20
 # grid的生成方式等价于用mesh_grid
new_h = torch.linspace(-1, 1, out_h).view(-1, 1).repeat(1, out_w)
new_w = torch.linspace(-1, 1, out_w).repeat(out_h, 1)
grid = torch.cat((new_h.unsqueeze(2), new_w.unsqueeze(2)), dim=2)
grid = grid.unsqueeze(0) #返回一个新的张量,对输入的既定位置插入维度 1
print(grid.shape)
outp = F.grid_sample(inp, grid=grid, mode='bilinear')
print(outp.shape)  #torch.Size([1, 1, 20, 20])

在上面的例子中,我们将一个大小为4×4的tensor 转换为了一个20×20的。grid的大小指定了输出大小,每个grid的位置是一个(x,y)坐标,其值来自于:输入input的(x,y)中 的四邻域插值得到的。

在这里插入图片描述
图片来自于SFnet(eccv2020)。flow field是grid, low_resolution是input, high resolution是output。

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注