{"id":3346,"date":"2022-03-31T17:25:24","date_gmt":"2022-03-31T09:25:24","guid":{"rendered":"http:\/\/139.9.1.231\/?p=3346"},"modified":"2022-04-27T14:49:14","modified_gmt":"2022-04-27T06:49:14","slug":"pytorch-modelsaveandload","status":"publish","type":"post","link":"http:\/\/139.9.1.231\/index.php\/2022\/03\/31\/pytorch-modelsaveandload\/","title":{"rendered":"PyTorch \u65ad\u70b9\u8bad\u7ec3\uff0c\u6a21\u578b\u7684\u4fdd\u5b58\u548c\u52a0\u8f7d"},"content":{"rendered":"\n\n\n<p>pytorch\u4e2d\u4e0e\u4fdd\u5b58\u548c\u52a0\u8f7d\u6a21\u578b\u6709\u5173\u51fd\u6570\u6709\u4e09\u4e2a\uff1a<br>1.torch.save:\u5c06\u5e8f\u5217\u5316\u7684\u5bf9\u8c61\u4fdd\u5b58\u5230\u78c1\u76d8\u3002\u6b64\u51fd\u6570\u4f7f\u7528Python\u7684pickle\u5b9e\u7528\u7a0b\u5e8f\u8fdb\u884c\u5e8f\u5217\u5316\u3002\u4f7f\u7528\u6b64\u529f\u80fd\u53ef\u4ee5\u4fdd\u5b58\u5404\u79cd\u5bf9\u8c61\u7684\u6a21\u578b\uff0c\u5f20\u91cf\u548c\u5b57\u5178\u3002<br>2. torch.load:\u4f7f\u7528pickle\u7684unpickle\u5de5\u5177\u5c06pickle\u7684\u5bf9\u8c61\u6587\u4ef6\u53cd\u5e8f\u5217\u5316\u5230\u5185\u5b58\u4e2d\u3002\u5373\u52a0\u8f7dsave\u4fdd\u5b58\u7684\u4e1c\u897f\u3002<br>3. 
torch.nn.Module.load_state_dict:\u4f7f\u7528\u53cd\u5e8f\u5217\u5316\u7684state_dict\u52a0\u8f7d\u6a21\u578b\u7684\u53c2\u6570\u5b57\u5178\u3002\u6ce8\u610f\uff0c\u8fd9\u610f\u5473\u7740\u5b83\u7684\u4f20\u5165\u7684\u53c2\u6570\u5e94\u8be5\u662f\u4e00\u4e2astate_dict\u7c7b\u578b\uff0c\u4e5f\u5c31torch.load\u52a0\u8f7d\u51fa\u6765\u7684\u3002<\/p>\n\n\n\n<h2>\u6a21\u578b\u642d\u5efa\uff1a<\/h2>\n\n\n\n<pre class=\"wp-block-code\"><code># Define model  \nclass TheModelClass(nn.Module):  \n    def __init__(self):  \n        super(TheModelClass, self).__init__()  \n        self.conv1 = nn.Conv2d(3, 6, 5)  \n        self.pool = nn.MaxPool2d(2, 2)  \n        self.conv2 = nn.Conv2d(6, 16, 5)  \n        self.fc1 = nn.Linear(16 * 5 * 5, 120)  \n        self.fc2 = nn.Linear(120, 84)  \n        self.fc3 = nn.Linear(84, 10)  \n  \n    def forward(self, x):  \n        x = self.pool(F.relu(self.conv1(x)))  \n        x = self.pool(F.relu(self.conv2(x)))  \n        x = x.view(-1, 16 * 5 * 5)  \n        x = F.relu(self.fc1(x))  \n        x = F.relu(self.fc2(x))  \n        x = self.fc3(x)  \n        return x  \n  \n# Initialize model  \nmodel = TheModelClass()  \n  \n# Initialize optimizer  \noptimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  \n  \n# Print model's state_dict  \nprint(\"Model's state_dict:\")  \nfor param_tensor in model.state_dict():  \n    print(param_tensor, \"\\t\", model.state_dict()&#091;param_tensor].size())  \n  \n# Print optimizer's state_dict  \nprint(\"Optimizer's state_dict:\")  \nfor var_name in optimizer.state_dict():  \n    print(var_name, \"\\t\", optimizer.state_dict()&#091;var_name])  <\/code><\/pre>\n\n\n\n<p>output\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>Model's state_dict:\nconv1.weight     torch.Size(&#091;6, 3, 5, 5])\nconv1.bias   torch.Size(&#091;6])\nconv2.weight     torch.Size(&#091;16, 6, 5, 5])\nconv2.bias   torch.Size(&#091;16])\nfc1.weight   torch.Size(&#091;120, 400])\nfc1.bias     
torch.Size(&#091;120])\nfc2.weight   torch.Size(&#091;84, 120])\nfc2.bias     torch.Size(&#091;84])\nfc3.weight   torch.Size(&#091;10, 84])\nfc3.bias     torch.Size(&#091;10])\n\nOptimizer's state_dict:\nstate    {}\nparam_groups     &#091;{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'pa<\/code><\/pre>\n\n\n\n<h2>\u6062\u590d\u8bad\u7ec3\u5b9e\u4f8b<\/h2>\n\n\n\n<p>\u4fdd\u5b58\u6a21\u578b\u548c\u52a0\u8f7d\u6a21\u578b\u7684\u51fd\u6570\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def  save_checkpoint_state(dir,epoch,model,optimizer):\n\t#\u4fdd\u5b58\u6a21\u578b\n    checkpoint = {\n            'epoch': epoch,\n            'model_state_dict': model.state_dict(),\n            'optimizer_state_dict': optimizer.state_dict(),\n                }   \n    if not os.path.isdir(dir):\n        os.mkdir(dir)\n\n    torch.save(checkpoint, os.path.join(dir,'checkpoint-epoch%d.tar'%(epoch)))\n    \ndef get_checkpoint_state(dir,ckp_name,device,model,optimizer):\n     # \u6062\u590d\u4e0a\u6b21\u7684\u8bad\u7ec3\u72b6\u6001\n    print(\"Resume from checkpoint...\")\n    checkpoint = torch.load(os.path.join(dir,ckp_name),map_location=device)\n    model.load_state_dict(checkpoint&#091;'model_state_dict'])\n    epoch=checkpoint&#091;'epoch']\n\n    optimizer.load_state_dict(checkpoint&#091;'optimizer_state_dict'])\n    #scheduler.load_state_dict(checkpoint&#091;'scheduler_state_dict'])\n\n    print('successfully recover from the last state')\n    return model,epoch,optimizer\n<\/code><\/pre>\n\n\n\n<p>\u5982\u679c\u52a0\u5165\u4e86lr_scheduler\uff0c\u90a3\u4e48lr_scheduler\u7684state_dict\u4e5f\u8981\u52a0\u8fdb\u6765\u3002<\/p>\n\n\n\n<p>\u4f7f\u7528\u65f6\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code># \u5f15\u7528\u5305\u7701\u7565\n#\u4fdd\u6301\u6a21\u578b\u51fd\u6570\ndef save_checkpoint_state(epoch, model, optimizer, scheduler, running_loss):\n    checkpoint = {\n        \"epoch\": epoch,\n        
\"model_state_dict\": model.state_dict(),\n        \"optimizer_state_dict\": optimizer.state_dict(),\n        \"scheduler_state_dict\": scheduler.state_dict()\n    }\n    \n    torch.save(checkpoint, \"checkpoint-epoch%d-loss%d.tar\" % (epoch, running_loss))\n# \u52a0\u8f7d\u6a21\u578b\u51fd\u6570   \ndef load_checkpoint_state(path, device, model, optimizer, scheduler):\n    checkpoint = torch.load(path, map_location=device)\n    \n    model.load_state_dict(checkpoint&#091;\"model_state_dict\"])\n    \n    epoch = checkpoint&#091;\"epoch\"]\n    \n    optimizer.load_state_dict(checkpoint&#091;\"optimizer_state_dict\"])\n    \n    scheduler.load_state_dict(checkpoint&#091;\"scheduler_state_dict\"])\n    \n    return model, epoch, optimizer, scheduler  \n\n\n# \u662f\u5426\u6062\u590d\u8bad\u7ec3\uff08\u5982\u679c\u662f\u6062\u590d\u8bad\u7ec3\uff0c\u90a3\u4e48\u9700\u8981\u8bbe\u7f6e\u4e3atrue\uff09\nresume = False # True\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\ndef train(): \n    trans = transforms.Compose(&#091;\n        transforms.ToPILImage(),\n        transforms.RandomResizedCrop(512),\n        transforms.RandomHorizontalFlip(),\n        transforms.RandomVerticalFlip(),\n        transforms.RandomRotation(90),\n        transforms.ToTensor(),\n        transforms.Normalize(mean=&#091;0.485, 0.456, 0.406], std=&#091;0.229, 0.224, 0.225])\n    ])\n    \n    # get training dataset\n    leafDiseaseCLS = CustomDataSet(images_path, is_to_ls, trans)\n    \n    data_loader = DataLoader(leafDiseaseCLS,\n                             batch_size=16,\n                             num_workers=0,\n                             shuffle=True,\n                             pin_memory=False)\n    \n    # get model\n    model = EfficientNet.from_pretrained(\"efficientnet-b3\")\n    \n    # extract the parameter of fully connected layer\n    fc_features = model._fc.in_features\n    # modify the number of classes\n    model._fc = 
nn.Linear(fc_features, 5)\n    \n    model.to(device)\n        \n    # optimizer\n    optimizer = optim.SGD(model.parameters(), \n                          lr=0.001, \n                          momentum=0.9,\n                          weight_decay=5e-4)\n    \n    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=&#091;6, 10], gamma=1\/3.)\n    \n    # loss\n    #loss_func = nn.CrossEntropyLoss()\n    loss_func = FocalCosineLoss()\n    \n    start_epoch = -1\n    \n    if resume:\n        model, start_epoch, optimizer,scheduler = load_checkpoint_state(\"..\/path\/to\/checkpoint.tar\",\n                                                                        device, \n                                                                        model,\n                                                                        optimizer,\n                                                                        scheduler)\n    \n    model.train()\n    \n    epochs = 3\n    \n    for epoch in range(start_epoch + 1, epochs):\n        \n        running_loss = 0.0\n        \n        print(\"Epoch {}\/{}\".format(epoch, epochs))\n        \n        for step, train_data in tqdm(enumerate(data_loader)):\n            x_train, y_train = train_data\n            \n            x_train = Variable(x_train.to(device))\n            y_train = Variable(y_train.to(device))\n            \n            # forward\n            prediction = model(x_train)\n            \n            optimizer.zero_grad()\n            \n            loss = loss_func(prediction, y_train)\n            \n            running_loss += loss.item()\n            \n            # backward\n            loss.backward()\n            \n            optimizer.step()            \n            \n            \n        scheduler.step()\n        \n        # saving model\n        torch.save(model.state_dict(), str(int(running_loss)) + \"_\" + str(epoch) + \".pth\")\n        \n        save_checkpoint_state(epoch, model, optimizer, 
scheduler, running_loss)\n        \n        print(\"Loss:{}\".format(running_loss))\n\nif __name__ == \"__main__\":\n    train()\n<\/code><\/pre>\n\n\n\n<h2>\u52a0\u8f7d\u90e8\u5206\u9884\u8bad\u7ec3\u6a21\u578b<\/h2>\n\n\n\n<p>\u5927\u591a\u6570\u65f6\u5019\u6211\u4eec\u9700\u8981\u6839\u636e\u6211\u4eec\u7684\u4efb\u52a1\u8c03\u8282\u6211\u4eec\u7684\u6a21\u578b\uff0c\u6240\u4ee5\u5f88\u96be\u4fdd\u8bc1\u6a21\u578b\u548c\u516c\u5f00\u7684\u6a21\u578b\u5b8c\u5168\u4e00\u6837\uff0c\u4f46\u662f\u9884\u8bad\u7ec3\u6a21\u578b\u7684\u53c2\u6570\u786e\u5b9e\u6709\u52a9\u4e8e\u63d0\u9ad8\u8bad\u7ec3\u7684\u51c6\u786e\u7387\uff0c\u4e3a\u4e86\u7ed3\u5408\u4e8c\u8005\u7684\u4f18\u70b9\uff0c\u5c31\u9700\u8981\u6211\u4eec\u52a0\u8f7d\u90e8\u5206\u9884\u8bad\u7ec3\u6a21\u578b\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>pretrained_dict = torch.load(\"model_data\/yolo_weights.pth\", map_location=device)\n\nmodel_dict = model.state_dict()\n# \u5c06 pretrained_dict \u91cc\u4e0d\u5c5e\u4e8e model_dict \u7684\u952e\u5254\u9664\u6389\npretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n#pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict&#091;k]) ==  np.shape(v)}\n# \u66f4\u65b0\u73b0\u6709\u7684 model_dict\nmodel_dict.update(pretrained_dict)\n# \u52a0\u8f7d\u6211\u4eec\u771f\u6b63\u9700\u8981\u7684 state_dict\nmodel.load_state_dict(model_dict)\n<\/code><\/pre>\n\n\n\n<h2>\u8de8\u8bbe\u5907\u4fdd\u5b58\/\u52a0\u8f7d\u6a21\u578b\uff08CPU\u4e0eGPU\uff09<\/h2>\n\n\n\n<h3><a><\/a>\u6a21\u578b\u4fdd\u5b58\u5728GPU\u4e0a\uff0c\u52a0\u8f7d\u5230CPU<\/h3>\n\n\n\n<ul><li>\u4fdd\u5b58<\/li><\/ul>\n\n\n\n<pre class=\"wp-block-code\"><code>torch.save(model.state_dict(), PATH)\n<\/code><\/pre>\n\n\n\n<ul><li>\u52a0\u8f7d:<\/li><\/ul>\n\n\n\n<pre class=\"wp-block-code\"><code>device = torch.device('cpu')\nmodel = TheModelClass(*args, **kwargs)\nmodel.load_state_dict(torch.load(PATH, 
map_location=device))\n<\/code><\/pre>\n\n\n\n<h2><a><\/a>\u6a21\u578b\u4fdd\u5b58\u5728GPU\u4e0a\uff0c\u52a0\u8f7d\u5230GPU<\/h2>\n\n\n\n<p><a><\/a>\u4fdd\u5b58:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>torch.save(model.state_dict(), PATH)\n<\/code><\/pre>\n\n\n\n<ul><li>\u52a0\u8f7d:<\/li><\/ul>\n\n\n\n<pre class=\"wp-block-code\"><code>device = torch.device(\"cuda\")\nmodel = TheModelClass(*args, **kwargs)\nmodel.load_state_dict(torch.load(PATH))\nmodel.to(device)\n<em># Make sure to call input = input.to(device) on any input tensors that you feed to the model<\/em><\/code><\/pre>\n\n\n\n<p><\/p>\n\n\n\n<h2>\u91cd\u70b9\uff1a\u5728\u4e8eepoch\u7684\u6062\u590d<\/h2>\n\n\n\n<p>\u4fdd\u5b58\u7684\u65f6\u5019\u9700\u8981\u5c06 epoch\u4e5f\u4fdd\u5b58<\/p>\n\n\n\n<p>\u4ee3\u7801\uff1a\u5b9e\u73b0\u6bcf\u9694N\u4e2aepoch\uff0csave\u6a21\u578b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>optimizer = torch.optim.SGD(model.parameters(),lr=0.1)\nlr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=&#091;10,20,30,40,50],gamma=0.1)\nstart_epoch = 9\n<em># print(schedule)<\/em>\n\n\nif RESUME:\n    path_checkpoint = \".\/model_parameter\/test\/ckpt_best_50.pth\"  <em># \u65ad\u70b9\u8def\u5f84<\/em>\n    checkpoint = torch.load(path_checkpoint)  <em># \u52a0\u8f7d\u65ad\u70b9<\/em>\n\n    model.load_state_dict(checkpoint&#091;'net'])  <em># \u52a0\u8f7d\u6a21\u578b\u53ef\u5b66\u4e60\u53c2\u6570<\/em>\n\n    optimizer.load_state_dict(checkpoint&#091;'optimizer'])  <em># \u52a0\u8f7d\u4f18\u5316\u5668\u53c2\u6570<\/em>\n    start_epoch = checkpoint&#091;'epoch']  <em># \u8bbe\u7f6e\u5f00\u59cb\u7684epoch<\/em>\n    lr_schedule.load_state_dict(checkpoint&#091;'lr_schedule'])\n\nfor epoch in range(start_epoch+1,80):\n\n    optimizer.zero_grad()\n\n    optimizer.step()\n    lr_schedule.step()\n\n\n    if epoch %10 ==0:\n        print('epoch:',epoch)\n        print('learning rate:',optimizer.state_dict()&#091;'param_groups']&#091;0]&#091;'lr'])\n  
      checkpoint = {\n            \"net\": model.state_dict(),\n            'optimizer': optimizer.state_dict(),\n            \"epoch\": epoch,\n            'lr_schedule': lr_schedule.state_dict()\n        }\n        if not os.path.isdir(\".\/model_parameter\/test\"):\n            os.mkdir(\".\/model_parameter\/test\")\n        torch.save(checkpoint, '.\/model_parameter\/test\/ckpt_best_%s.pth' % (str(epoch)))<\/code><\/pre>\n\n\n\n<h2> \u8bbe\u7f6e\u968f\u673a\u6570\u79cd\u5b50 \uff0c\u4f7f\u5f97\u8bad\u7ec3\u8fc7\u7a0b\u7ed3\u679c\u53ef\u590d\u73b0<\/h2>\n\n\n\n<p>PyTorch\u65f6\uff0c\u5982\u679c\u5e0c\u671b\u901a\u8fc7\u8bbe\u7f6e\u968f\u673a\u6570\u79cd\u5b50\uff0c\u5728gpu\u6216cpu\u4e0a\u56fa\u5b9a\u6bcf\u4e00\u6b21\u7684\u8bad\u7ec3\u7ed3\u679c\uff0c\u5219\u9700\u8981\u5728\u7a0b\u5e8f\u6267\u884c\u7684\u5f00\u59cb\u5904\u6dfb\u52a0\u4ee5\u4e0b\u4ee3\u7801\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def setup_seed(seed):\n     torch.manual_seed(seed)\n     torch.cuda.manual_seed_all(seed)\n     np.random.seed(seed)\n     random.seed(seed)\n     torch.backends.cudnn.deterministic = True\n<em># \u8bbe\u7f6e\u968f\u673a\u6570\u79cd\u5b50<\/em>\nsetup_seed(20)\n<em># \u9884\u5904\u7406\u6570\u636e\u4ee5\u53ca\u8bad\u7ec3\u6a21\u578b<\/em>\n<em># ...<\/em>\n<em># ...<\/em><\/code><\/pre>\n\n\n\n<p class=\"has-light-pink-background-color has-background\">\u968f\u673a\u6570\u79cd\u5b50seed\u786e\u5b9a\u65f6\uff0c\u4e0d\u6539\u53d8\u7a0b\u5e8f\u53c2\u6570\u60c5\u51b5\u4e0b\uff0c\u4e24\u6b21\u6a21\u578b\u7684\u8bad\u7ec3\u7ed3\u679c\u5c06\u59cb\u7ec8\u4fdd\u6301\u4e00\u81f4\u3002<\/p>\n","protected":false},"excerpt":{"rendered":"<p>pytorch\u4e2d\u4e0e\u4fdd\u5b58\u548c\u52a0\u8f7d\u6a21\u578b\u6709\u5173\u51fd\u6570\u6709\u4e09\u4e2a\uff1a1.torch.save:\u5c06\u5e8f\u5217\u5316\u7684\u5bf9\u8c61\u4fdd\u5b58\u5230\u78c1\u76d8\u3002\u6b64\u51fd\u6570\u4f7f\u7528 &hellip; <a href=\"http:\/\/139.9.1.231\/index.php\/2022\/03\/31\/pytorch-modelsaveandload\/\" 
class=\"more-link\">\u7ee7\u7eed\u9605\u8bfb<span class=\"screen-reader-text\">PyTorch \u65ad\u70b9\u8bad\u7ec3\uff0c\u6a21\u578b\u7684\u4fdd\u5b58\u548c\u52a0\u8f7d<\/span><\/a><\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":[],"categories":[11],"tags":[],"_links":{"self":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/3346"}],"collection":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/comments?post=3346"}],"version-history":[{"count":16,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/3346\/revisions"}],"predecessor-version":[{"id":3362,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/3346\/revisions\/3362"}],"wp:attachment":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/media?parent=3346"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/categories?post=3346"},{"taxonomy":"post_tag","embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/tags?post=3346"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}