最近看到很多论文里都有这个函数(yolov3 以及最近大火的swin transformer),记录下函数的使用:
https://pytorch.org/docs/stable/generated/torch.meshgrid.html
说明:
torch.meshgrid()的功能是生成网格,可以用于生成坐标。
函数输入:
输入两个数据类型相同的一维tensor
函数输出:
输出两个tensor(tensor行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数)
注意:
1)当两个输入tensor数据类型不同或维度不是一维时会报错。
2)其中第一个输出张量填充第一个输入张量中的元素,各行元素相同;第二个输出张量填充第二个输入张量中的元素各列元素相同。
>>> x = torch.tensor([1, 2, 3]) >>> y = torch.tensor([4, 5, 6]) Observe the element-wise pairings across the grid, (1, 4), (1, 5), ..., (3, 6). This is the same thing as the cartesian product. >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij') >>> grid_x tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) >>> grid_y tensor([[4, 5, 6], [4, 5, 6], [4, 5, 6]])
# 【1】
import torch
a = torch.tensor([1, 2, 3, 4])
print(a)
b = torch.tensor([4, 5, 6])
print(b)
x, y = torch.meshgrid(a, b)
print(x)
print(y)
结果显示:
tensor([1, 2, 3, 4])
tensor([4, 5, 6])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]])
tensor([[4, 5, 6],
[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
# 【2】
import torch
a = torch.tensor([1, 2, 3, 4, 5, 6])
print(a)
b = torch.tensor([7, 8, 9, 10])
print(b)
x, y = torch.meshgrid(a, b)
print(x)
print(y)
结果显示:
tensor([1, 2, 3, 4, 5, 6])
tensor([ 7, 8, 9, 10])
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4],
[5, 5, 5, 5],
[6, 6, 6, 6]])
tensor([[ 7, 8, 9, 10],
[ 7, 8, 9, 10],
[ 7, 8, 9, 10],
[ 7, 8, 9, 10],
[ 7, 8, 9, 10],
[ 7, 8, 9, 10]])