如何使用dataloader加载相同维度但是不同尺寸的数据集(图片),不使用resize,crop等改变模型输入的shape。
知乎:https://www.zhihu.com/question/395888465
如果加载的数据的维度尺寸不相同的话,在迭代器中会爆出如下的错误
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0.
1、pytorch的dataloader默认的collate_fn会使用torch.stack合并多张图片成为batch
要么另外写一个collate_fn
要么在dataset类中对图片做padding,使得图片的size一样,可以直接stack
2、关于collate_fn:
https://pytorch.org/docs/stable/data.html#working-with-collate-fn
The use of collate_fn
is slightly different when automatic batching is enabled or disabled.
- When automatic batching is disabled,
collate_fn
is called with each individual data sample, and the output is yielded from the data loader iterator. In this case, the defaultcollate_fn
simply converts NumPy arrays in PyTorch tensors. - When automatic batching is enabled,
collate_fn
is called with a list of data samples at each time. It is expected to collate the input samples into a batch for yielding from the data loader iterator. The rest of this section describes behavior of the defaultcollate_fn
in this case.
可以看到,你可以考虑关闭自动打包,这样collate_fn处理的就是独立的样本。也可以打开自动打包,这样这个函数就会被输入一个batch列表的数据。注意,这个列表的数据可以不同大小哦,知识这样你就没办法将其stack
成一个完整的batch。所以,实际上你的报错,应该是这个位置出的问题。
所以可以考虑以下几种策略:
- 单个样本输入,这样同一个batch组合的时候就不需要担心了
- 对输入样本padding成最大的形状,组合成batch,之后送入网络的时候,你可以把数据拆分开,按你想要的将其去掉padding或者其他操作
- 正常读取,之后再自定义的collate_fn中将数据拆开返回,这样可以返回相同结构的数据
对于最后一点,给个小demo:
class OurDataset(Dataset):
def __init__(self, *tensors):
self.tensors = tensors
def __getitem__(self, index):
return self.tensors[index]
def __len__(self):
return len(self.tensors)
def collate_wrapper(batch):
#函数就会输入一个batch的列表的数据(注意是batch是一个列表,所以里面的数据可以不同大小)
a, b = batch
return a, b
a = torch.randn(3, 2, 3)
b = torch.randn(3, 3, 4)
dataset = OurDataset(a, b)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper)
for sample in loader:
print([x.size() for x in sample])
# Out: [torch.Size([1, 3, 2, 3]), torch.Size([1, 3, 3, 4])]