The question about the mask used in (shifted) window attention:
https://github.com/microsoft/Swin-Transformer/issues/38
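For context, the cyclic shift behind that issue is done with torch.roll itself. A minimal sketch (shapes and the shift_size value here are illustrative, but the roll call mirrors the pattern used in the repository's SwinTransformerBlock):

```python
import torch

# Hypothetical layout: a batch of feature maps as (B, H, W, C),
# as used before window partitioning in Swin.
B, H, W, C = 1, 8, 8, 3
shift_size = 2  # illustrative; Swin uses window_size // 2

x = torch.arange(B * H * W * C, dtype=torch.float32).view(B, H, W, C)

# Cyclic shift: move the feature map up and left; rows/columns that fall
# off the top/left edge wrap around to the bottom/right.
shifted = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

# Rolling back by the same amounts restores the original tensor exactly.
restored = torch.roll(shifted, shifts=(shift_size, shift_size), dims=(1, 2))
```

Because the wrap-around puts pixels from opposite edges into the same window, those windows need a mask so unrelated regions do not attend to each other, which is what the linked issue discusses.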
torch.roll(input, shifts, dims=None) → Tensor
Roll the tensor input along the given dimension(s). Elements that are shifted beyond the last position are re-introduced at the first position. If dims is None, the tensor will be flattened before rolling and then restored to the original shape.
Parameters
- input (Tensor) – the input tensor.
- shifts (int or tuple of ints) – The number of places by which the elements of the tensor are shifted. If shifts is a tuple, dims must be a tuple of the same size, and each dimension will be rolled by the corresponding value.
- dims (int or tuple of ints) – Axis along which to roll.
In other words: the tensor is rolled along the given dimension(s), and elements pushed past the last position wrap around to the first. If no dims is specified, the tensor is flattened before rolling and then restored to its original shape.
An intuitive picture: a positive shifts value is like squeezing a toothpaste tube downward, with the paste that comes out stuffed back in at the top; a negative shifts value squeezes upward, with the overflow stuffed back in at the bottom.
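The toothpaste picture can be checked directly on a 1-D tensor (a small sketch; the values are arbitrary):

```python
import torch

t = torch.tensor([1, 2, 3, 4, 5])

# Positive shift: elements move toward higher indices ("squeezed downward"),
# and the two that fall off the end wrap back to the front.
down = torch.roll(t, 2)   # tensor([4, 5, 1, 2, 3])

# Negative shift: elements move toward lower indices ("squeezed upward"),
# and the two that fall off the front wrap back to the end.
up = torch.roll(t, -2)    # tensor([3, 4, 5, 1, 2])
```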
- input (Tensor) — the input tensor.
- shifts (int or tuple of ints) — the number of places by which the elements of the tensor are shifted. If this is a tuple (e.g. shifts=(x, y)), dims must be a tuple of the same size (e.g. dims=(a, b)); the tensor is then shifted by x along dimension a and by y along dimension b.
- dims (int or tuple of ints) — the dimension(s) along which to roll.
Example:
>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
>>> x
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
# With dims=None the tensor is flattened first, rolled by 1, then reshaped:
>>> torch.roll(x, 1)
tensor([[8, 1],
        [2, 3],
        [4, 5],
        [6, 7]])
# Dimension 0 shifts down by 1; the overflowing [7, 8] wraps around to the top:
>>> torch.roll(x, 1, 0)
tensor([[7, 8],
        [1, 2],
        [3, 4],
        [5, 6]])
# Dimension 0 shifts up by 1; the overflowing [1, 2] wraps around to the bottom:
>>> torch.roll(x, -1, 0)
tensor([[3, 4],
        [5, 6],
        [7, 8],
        [1, 2]])
# With tuples, shifts and dims pair up one-to-one:
# dimension 0 shifts down by 2, so the overflowing [5, 6] and [7, 8] wrap to the top;
# dimension 1 shifts right by 1, so the overflowing [6, 8, 2, 4] wraps to the left edge:
>>> torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[6, 5],
        [8, 7],
        [2, 1],
        [4, 3]])
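Tying this back to the issue linked at the top: after the cyclic roll, each window can contain pixels from up to four originally disconnected regions, and the mask gives cross-region pairs a large negative attention bias. A hedged sketch of that construction, adapted from the Swin repository's mask-building code (the sizes here are toy values, and window partitioning is inlined for brevity):

```python
import torch

H = W = 4
window_size = 2
shift_size = 1

# Label every pixel with the region it belonged to before the cyclic shift.
img_mask = torch.zeros(1, H, W, 1)
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = h_slices
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# Partition into non-overlapping windows (a simplified window_partition).
windows = img_mask.view(1, H // window_size, window_size,
                        W // window_size, window_size, 1)
windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)

# Pairs of positions from different pre-shift regions get a large negative
# bias (masked out); pairs within the same region get 0.
attn_mask = windows.unsqueeze(1) - windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)
```

Adding attn_mask to the attention logits before the softmax drives the cross-region entries toward zero weight, so the rolled-in pixels never mix with their new neighbors.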