Pytorch Image Models –timm快速使用

原文:Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide – 2022.02.02

中文教程: https://www.aiuai.cn/aifarm1967.html

Github: rwightman/pytorch-image-models

PyTorch Image Models(timm) 是一个优秀的图像分类 Python 库,其包含了大量的图像模型(Image Models)、Optimizers、Schedulers、Augmentations 等等.里面提供了许多计算机视觉的SOTA模型,可以当作是torchvision的扩充版本,并且里面的模型在准确度上也较高。

timm 提供了参考的 training 和 validation 脚本,用于复现在 ImageNet 上的训练结果;以及更多的 官方文档 和 timmdocs project.

timm的安装

关于timm的安装,我们可以选择以下两种方式进行:

  1. 通过pip安装
pip install timm
  1. 通过git与pip进行安装
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .

如何查看预训练模型种类

  1. 查看timm提供的预训练模型 截止到2022.3.27日为止,timm提供的预训练模型已经达到了592个,我们可以通过timm.list_models()方法查看timm提供的预训练模型(注:本章测试代码均是在jupyter notebook上进行)
import timm
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models)
  1. 查看特定模型的所有种类 每一种系列可能对应着不同方案的模型,比如Resnet系列就包括了ResNet18,50,101等模型,我们可以在timm.list_models()传入想查询的模型名称(模糊查询),比如我们想查询densenet系列的所有模型。
all_densnet_models = timm.list_models("*densenet*")
all_densnet_models

我们发现以列表的形式返回了所有densenet系列的所有模型。

['densenet121',
 'densenet121d',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenet264',
 'densenet264d_iabn',
 'densenetblur121d',
 'tv_densenet121']
  1. 查看模型的具体参数 当我们想查看下模型的具体参数的时候,我们可以通过访问模型的default_cfg属性来进行查看,具体操作如下
model = timm.create_model('resnet34',num_classes=10,pretrained=True)
model.default_cfg
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
 'num_classes': 1000,
 'input_size': (3, 224, 224),
 'pool_size': (7, 7),
 'crop_pct': 0.875,
 'interpolation': 'bilinear',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'first_conv': 'conv1',
 'classifier': 'fc',
 'architecture': 'resnet34'}

除此之外,我们可以通过访问这个链接 查看提供的预训练模型的准确度等信息。

使用和修改预训练模型

在得到我们想要使用的预训练模型后,我们可以通过timm.create_model()的方法来进行模型的创建,我们可以通过传入参数pretrained=True,来使用预训练模型。同样的,我们也可以使用跟torchvision里面的模型一样的方法查看模型的参数,类型/

import timm
import torch

model = timm.create_model('resnet34',pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape
torch.Size([1, 1000])
  • 查看某一层模型参数(以第一层卷积为例)
model = timm.create_model('resnet34',pretrained=True)
list(dict(model.named_children())['conv1'].parameters())
[Parameter containing:
 tensor([[[[-2.9398e-02, -3.6421e-02, -2.8832e-02,  ..., -1.8349e-02,
            -6.9210e-03,  1.2127e-02],
           [-3.6199e-02, -6.0810e-02, -5.3891e-02,  ..., -4.2744e-02,
            -7.3169e-03, -1.1834e-02],
            ...
           [ 8.4563e-03, -1.7099e-02, -1.2176e-03,  ...,  7.0081e-02,
             2.9756e-02, -4.1400e-03]]]], requires_grad=True)]
            
  • 修改模型(将1000类改为10类输出)
model = timm.create_model('resnet34',num_classes=10,pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape
torch.Size([1, 10])
  • 改变输入通道数(比如我们传入的图片是单通道的,但是模型需要的是三通道图片) 我们可以通过添加in_chans=1来改变
model = timm.create_model('resnet34',num_classes=10,pretrained=True,in_chans=1)
x = torch.randn(1,1,224,224)
output = model(x)

模型的保存

timm库所创建的模型是torch.model的子类,我们可以直接使用torch库中内置的模型参数保存和加载的方法,具体操作如下方代码所示

torch.save(model.state_dict(),'./checkpoint/timm_model.pth')
model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))

使用示例

# replace
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# with
optimizer = timm.optim.AdamP(model.parameters(), lr=0.01)

for epoch in num_epochs:
    for batch in training_dataloader:
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        
#
optimizer = timm.optim.Adahessian(model.parameters(), lr=0.01)

is_second_order = (
    hasattr(optimizer, "is_second_order") and optimizer.is_second_order
)  # True

for epoch in num_epochs:
    for batch in training_dataloader:
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)

        loss.backward(create_graph=second_order)
        optimizer.step()
        optimizer.zero_grad()

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注