Average weights of two PyTorch models

2022-07-15 08:43

After reading this paper, I started an experiment based on it. Referencing this snippet, I wrote my code:

    net1 = model_builder.build_model()
    net2 = model_builder.build_model()
    output = model_builder.build_model()
    net1.load_state_dict(torch.load(args.model1, map_location="cpu"))
    net2.load_state_dict(torch.load(args.model2, map_location="cpu"))
    
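    # Average the parameters of net1 into net2's parameters, in place
    sd1 = net1.named_parameters()
    sd2 = net2.named_parameters()
    sdo = dict(sd2)
    for name, param in sd1:
        sdo[name].data.copy_(0.5*param.data + 0.5*sdo[name].data)

    output.load_state_dict(sdo)
    # Note: this pickles the whole model object, not just its state_dict
    torch.save(output, args.output)

    # Test: try to load the saved file back as a state_dict
    output.load_state_dict(torch.load(args.output))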

But after generating the new model with averaged weights, PyTorch failed to load it:

Traceback (most recent call last):
  File "average_models.py", line 43, in <module>
    output.load_state_dict(torch.load(args.output))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1534, in load_state_dict
    state_dict = state_dict.copy()
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1186, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'RegNet' object has no attribute 'copy'
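The failure can be reproduced on a tiny model. The sketch below is not the original script: it assumes a stand-in nn.Linear instead of the RegNet, writes to in-memory buffers instead of args.output, and passes weights_only=False (an argument that, as far as I know, is available from PyTorch 1.13) so that the pickled module can be read back on recent versions.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical stand-in for the RegNet

# Saving the module itself pickles the whole object.
buf_model = io.BytesIO()
torch.save(model, buf_model)
buf_model.seek(0)
loaded_obj = torch.load(buf_model, weights_only=False)

# Saving the state_dict stores only a name -> tensor mapping.
buf_sd = io.BytesIO()
torch.save(model.state_dict(), buf_sd)
buf_sd.seek(0)
loaded_sd = torch.load(buf_sd)

print(type(loaded_obj).__name__)  # a module class, not a dict
print(type(loaded_sd).__name__)   # a dict-like mapping

# Passing the module to load_state_dict fails. Depending on the PyTorch
# version this surfaces as AttributeError (as in the traceback above)
# or as TypeError.
try:
    model.load_state_dict(loaded_obj)
    error = None
except (AttributeError, TypeError) as exc:
    error = exc
print(repr(error))
```

Loading loaded_sd with model.load_state_dict(loaded_sd) works, because it is the mapping that load_state_dict expects.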

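The reason for the failure is quite simple: torch.save(output, args.output) saves the entire model object, so torch.load returns a RegNet instance rather than a state_dict, and load_state_dict fails when it calls .copy() on it. We only need to save the state_dict of the model instead of all of its information (since I am using the FP16 format). Therefore the correct code should be: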

    net1 = model_builder.build_model()
    net2 = model_builder.build_model()
    net1.load_state_dict(torch.load(args.model1, map_location="cpu"))
    net2.load_state_dict(torch.load(args.model2, map_location="cpu"))

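    # Average the parameters of net1 into net2's parameters, in place
    sd1 = net1.named_parameters()
    sd2 = net2.named_parameters()
    sdo = dict(sd2)
    for name, param in sd1:
        sdo[name].data.copy_(0.5*param.data + 0.5*sdo[name].data)

    # Save only the name -> tensor mapping, which load_state_dict can read
    torch.save(sdo, args.output)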
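One caveat: named_parameters() yields only trainable parameters, not buffers such as BatchNorm's running_mean and running_var, so a file saved this way may be missing keys that a strict load_state_dict expects. A possible variant, sketched here under assumptions, averages the full state_dict instead. The helper average_state_dicts and the tiny make_net model are hypothetical names for illustration, not part of the original script, and copying integer buffers from the first model is just one reasonable choice.

```python
import torch
import torch.nn as nn


def average_state_dicts(sd_a, sd_b, alpha=0.5):
    """Return alpha * sd_a + (1 - alpha) * sd_b, key by key."""
    if set(sd_a) != set(sd_b):
        raise ValueError("state dicts have different keys")
    averaged = {}
    for name, tensor_a in sd_a.items():
        tensor_b = sd_b[name]
        if tensor_a.is_floating_point():
            averaged[name] = alpha * tensor_a + (1.0 - alpha) * tensor_b
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) are
            # copied from the first model rather than averaged.
            averaged[name] = tensor_a.clone()
    return averaged


def make_net():
    # Hypothetical tiny model standing in for model_builder.build_model()
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))


net1, net2, output = make_net(), make_net(), make_net()
avg = average_state_dicts(net1.state_dict(), net2.state_dict())

# The default strict=True load succeeds because buffers are included.
output.load_state_dict(avg)

# The averaged mapping can then be saved with torch.save(avg, path).
```

With alpha=0.5 this is the same equal-weight average as above; other values of alpha give a weighted interpolation between the two models.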

By the way, in my experiment, averaging the weights of my models does not raise accuracy the way the paper suggests.


