代码片段-pytorch加载预训练模型的部分参数和设置不同的学习率
随着数据集越来越大,像ImageNet, COCO这种量级的数据大公司动辄上百张的GPU一起训练可能时间还不算啥,但对于大多数人来说,计算资源还是非常紧张的,所以最好的策略应该就是使用预训练的权重参数在自己的数据上进行微调。 而任务和数据的不同,导致我们的模型一定会做出相应的修改。如何从预训练的权重参数中抽取部分曾参数初始化我们自己的模型,以及如何给不同层设置学习率是这篇笔记的重点。
模型参数加载与保存
模型参数的保存
pytorch定义网络的参数保存有两种方式:
保存整个网络结构,以及网络参数
1
torch.save(net, 'net.pkl')
即将整个网络即对应参数保存到
net.pkl
文件中, 这一类显然会占用更多的存储空间以及存储时间。只保存模型的网络参数,而不保存结构,这时候相当于保存的是一个字典,键是每一层的名字,值则是对应的权重参数。(官方推荐方法)
1
torch.save(net.state_dict(), 'net_params.pkl')
net.state_dict()
相当于将模型参数字典话。
模型参数的加载
完整的模型参数加载,就是说如果类似于从之前训练的某一个断点进行resume或者inference,这是网络的结构没有发生变化,那么根据模型保存的方式不同,加载参数的方式也不同。分为两种。
加载模型结构以及对应参数
1
net = torch.load('model.pth')
加载保存的参数到已定义的网络结构中
1
net.load_state_dict(torch.load('net_params.pkl'))
模型参数裁剪
基于已有的模型构建适合自己数据的模型时,一般有两种做法。
直接在原有的网络上进行修改。这时候可以把原有的模型结构和参数同时加载下来,然后再修改某些特殊的层, 比如希望用ResNet分类10类数据,那么可能只需要修改最后一层的结构。
1
2net = torchvision.models.resnet18(pretrained=True)
net.fc = nn.Linear(512, 10)这时候前面resnet18的结构和参数都是保留的。可以通过设置不同的学习率来微调新的模型
自己定义好了模型,然后将与训练的模型参数加载到不同的层面。
- 这时候如果自己定义的模型和与训练的模型中公共的网络层,准确的说应该是state_dict中公共的键,如果对应的网络结构都是相同的,那么就可以直接通过宽泛的加载函数加载模型参数
1
net.load_state_dict(torch.load('net_params.pkl'), strict=False)
这里的
strict=False
表示如何键不相同那么就跳过该键的复制,但值得注意的是,如果两个网络有相同的键但结构不同,这时候不同采用这种方式加载参数,比如1
2
3
4
5
6
7class Net(nn.Module):
def __init__(self, d):
super(Net, self).__init__()
self.fc = nn.Linear(512, d)
net1, net2 = Net(20), Net(10)
net2.load_state_dict(net1.state_dict(), strict=False) #报错
# size mismatch for fc.weight:....- 第二种方式是将不希望加载的键值对直接剔除
1
2pretrained_params = {k:v for k, v in net1.state_dict().items() if k in net2.state_dict()}
net2.load_state_dict(pretrained_params)- 第三种方式处理不同GPU上的参数或者多卡的情形,可以通过重新赋值参数的方式实现
1
2
3
4
5
6
7
8
9
10
11def check_keys(model, pretrained_state_dict):
ckpt_keys = set(pretrained_state_dict.keys())
model_keys = set(model.state_dict().keys())
used_pretrained_keys = model_keys & ckpt_keys
unused_pretrained_keys = ckpt_keys - model_keys
missing_keys = model_keys - ckpt_keys
logger.info('missing keys:{}'.format(len(missing_keys)))
logger.info('unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
logger.info('used keys:{}'.format(len(used_pretrained_keys)))
assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
return True1
2
3
4
5def remove_prefix(state_dict, prefix):
''' Old style model is stored with all names of parameters share common prefix 'module.' '''
logger.info('remove prefix \'{}\''.format(prefix))
f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
return {f(key): value for key, value in state_dict.items()}1
2
3
4
5
6
7
8
9
10
11
12
13def load_pretrain(model, pretrained_path):
logger.info('load pretrained model from {}'.format(pretrained_path))
device = torch.cuda.current_device()
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
if "state_dict" in pretrained_dict.keys():
pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
else:
pretrained_dict = remove_prefix(pretrained_dict, 'module.')
check_keys(model, pretrained_dict)
model.load_state_dict(pretrained_dict, strict=False)
return model
设置不同的学习率
设置不同的学习率问题也可以分为两类:
只对某些层进行学习,而其他的层固定,不参与训练, pytorch对每一个variable提供了requires_grad参数来确定参数是否需要更新。
1
2
3
4for m in net.children(): # 遍历模型的每一个模块
if isinstance(m, nn.Linear): # 设置满足的条件,满足条件时不更新参数
for param in net.parameters(): # 遍历模型的每一个参数, 可能包括weight和bias等
param.requires_grad=False当不同的层需要设置不同的学习率时,这部分主要是通过优化器的优化参数组来实现的, papams_group
1
2
3
4
5
6
7
8base_params = list(map(id, net.backbone.parameters())) # backbone中每个参数的id
fc_params = filter(lambda p: id(p) not in base_params, net.parameters())
# 筛选net.parameters中id不在base_params中的参数
params = [
{"params": fc_params, "lr": args.lr},
{"params": net.backbone.parameters(), "lr": args.backbone_lr}
]
optimizer = torch.optim.SGD(params, momentum=args.m, weight_decay=args.wd)
动态调整学习率
网络训练过程中一般学习率会越来越低,学习率的修改一般是通过修改optimizer的param的学习率实现的。
1 | for param_group in optimizer.param_groups: |
这里学习率的跟新可以采用多种更新方式,如log线性下降,或者阶段是下降等
官方提供的按照milestone调整学习率的方式。
1 | def adjust_learning_rate(optimizer, epoch, milestones=None): |
本文作者 : zhouzongwei
原文链接 : http://yoursite.com/2019/05/28/pytorch-params-load/
版权声明 : 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明出处!
知识 & 情怀 | 赏或者不赏,我都在这,不声不响