pytorch中模型的保存和读取:torch.load torch.save
1、读取tensor
我们可以直接使用save
函数和load
函数分别存储和读取Tensor
。save
使用Python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,使用save
可以保存各种对象,包括模型、张量和字典等。而load
使用pickle unpickle工具将pickle的对象文件反序列化为内存
import torch
from torch import nn
x = torch.ones(3)
torch.save(x, 'x.pt')
x2 = torch.load('x.pt')\
存储一个Tensor列表并读回内存。
y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list
存储并读取一个从字符串映射到Tensor的字典。
torch.save({'x': x, 'y': y}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy
读写模型:
在PyTorch中,Module
的可学习参数(即权重和偏差),模块模型包含在参数中(通过model.parameters()
访问)。state_dict
是一个从参数名称隐射到参数Tesnor
的字典对象。
1、将模型和参数都保存和读取
torch.save(model, PATH)
model = torch.load(PATH)
2、只存储模型参数(state_dict
)
torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth
加载:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))