torch模型的保存和读取

pytorch中模型的保存和读取:torch.load torch.save

1、读取tensor

我们可以直接使用save函数和load函数分别存储和读取Tensorsave使用Python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,使用save可以保存各种对象,包括模型、张量和字典等。而load使用pickle unpickle工具将pickle的对象文件反序列化为内存

import torch
from torch import nn

x = torch.ones(3)
torch.save(x, 'x.pt')
x2 = torch.load('x.pt')\

存储一个Tensor列表并读回内存。
y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list

存储并读取一个从字符串映射到Tensor的字典。

torch.save({'x': x, 'y': y}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy

读写模型:

在PyTorch中,Module的可学习参数(即权重和偏差),模块模型包含在参数中(通过model.parameters()访问)。state_dict是一个从参数名称隐射到参数Tesnor的字典对象。

1、将模型和参数都保存和读取

torch.save(model, PATH)

model = torch.load(PATH)

2、只存储模型参数(state_dict)

torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth

加载:

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))

圣诞快乐

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注