在模型完成训练后,我们需要将训练好的模型保存为一个文件供测试使用,或者因为一些原因我们需要继续之前的状态训练之前保存的模型,那么如何在PyTorch中保存和恢复模型呢?
参考PyTorch官方的这份repo,我们知道有两种方法可以实现我们想要的效果。
方法一(推荐):
第一种方法也是官方推荐的方法,只保存和恢复模型中的参数。
保存
1 | torch.save(the_model.state_dict(), PATH) |
恢复
1 2 | the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH)) |
使用这种方法,我们需要自己导入模型的结构信息。
方法二:
使用这种方法,将会保存模型的参数和结构信息。
保存
1 | torch.save(the_model, PATH) |
恢复
1 | the_model = torch.load(PATH) |
一个相对完整的例子
saving
1 2 3 4 5 6 | torch.save({ 'epoch': epoch + 1, 'arch': args.arch, 'state_dict': model.state_dict(), 'best_prec1': best_prec1, }, 'checkpoint.tar' ) |
loading
1 2 3 4 5 6 7 8 9 | if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] best_prec1 = checkpoint['best_prec1'] model.load_state_dict(checkpoint['state_dict']) print("=> loaded checkpoint '{}' (epoch {})" .format(args.evaluate, checkpoint['epoch'])) |
获取模型中某些层的参数
对于恢复的模型,如果我们想查看某些层的参数,可以:
1 2 3 4 5 6 7 8 9 10 | # 定义一个网络 from collections import OrderedDict model = nn.Sequential(OrderedDict([ ('conv1', nn.Conv2d(1,20,5)), ('relu1', nn.ReLU()), ('conv2', nn.Conv2d(20,64,5)), ('relu2', nn.ReLU()) ])) # 打印网络的结构 print(model) |
Out:
1 2 3 4 5 6 | Sequential ( (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) (relu1): ReLU () (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1)) (relu2): ReLU () ) |
如果我们想获取conv1的weight和bias:
1 2 3 4 5 | params=model.state_dict() for k,v in params.items(): print(k) #打印网络中的变量名 print(params['conv1.weight']) #打印conv1的weight print(params['conv1.bias']) #打印conv1的bias |
文章来源:http://www.aiboy.pub/2017/06/05/How_To_Save_And_Restore_Model/
本站微信群、QQ群(三群号 726282629):
print(params[‘conv1.weight’]) #打印conv1的weight
print(params[‘conv1.bias’]) #打印conv1的bias
请问这里在运行时报错
‘ print(params[‘conv1’]) #打印conv1的weight
KeyError: ‘conv1’’