
Model saving error when using Apex

2022-06-03 09:56

Apex is NVIDIA's PyTorch extension for mixed-precision training. A typical training step using its amp module looks like this:

import torch
import apex.amp as amp

# Let Apex patch the model and optimizer for mixed precision (opt_level "O2")
net, optimizer = amp.initialize(net, optimizer, opt_level="O2")

# forward
outputs = net(inputs)
loss = criterion(outputs, targets)

optimizer.zero_grad()

# float16 backward: the loss is scaled to avoid gradient underflow
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

optimizer.step()

...

torch.save(net, "model.pth")

After I changed my code to use Apex, saving the whole model with torch.save(net, "model.pth") failed with:

AttributeError: Can't pickle local object '_initialize.<locals>.patch_forward.<locals>.new_fwd'
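The qualified name in the message, patch_forward.<locals>.new_fwd, is the clue: it is a function defined inside another function and attached to the model as its forward. pickle stores functions by importable name, so a local function cannot be serialized. The stdlib-only sketch below reproduces the pattern; Model and patch_forward are hypothetical stand-ins that only mimic what the traceback names, not Apex's actual code.

```python
import pickle


class Model:
    """Hypothetical stand-in for `net`: some data plus a forward method."""

    def __init__(self):
        self.weight = [1.0, 2.0]

    def forward(self, x):
        return [w * x for w in self.weight]


def patch_forward(old_fwd):
    # Hypothetical imitation of the wrapper named in the traceback: new_fwd is
    # defined inside patch_forward, so its qualified name is
    # 'patch_forward.<locals>.new_fwd' and pickle cannot look it up by name.
    def new_fwd(*args, **kwargs):
        return old_fwd(*args, **kwargs)

    return new_fwd


def pickle_error(obj):
    """Return None if obj pickles, else the error message."""
    try:
        pickle.dumps(obj)
    except (AttributeError, pickle.PicklingError) as exc:
        return str(exc)
    return None


model = Model()
print(pickle_error(model))  # None: the unpatched instance pickles fine

# Attach the wrapper to the instance, mimicking how the model's forward is patched.
model.forward = patch_forward(model.forward)
print(model.forward(3.0))          # calling still works: [3.0, 6.0]
print(pickle_error(model))         # now fails, and the message names new_fwd
print(pickle_error(model.weight))  # None: the plain data alone still pickles
```

Calling the patched model still works; only serializing the whole object breaks, which matches the symptom above.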

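Others have already reported this problem (link), but it does not seem to have been fixed. The only workaround I found comes from a Chinese blog (link). The traceback explains why that workaround helps: after amp.initialize, net.forward is a locally defined wrapper (new_fwd) that pickle cannot serialize, so pickling the whole module fails. A state_dict holds only the parameter and buffer tensors, so the fix is to save just the model parameters: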

torch.save(net.state_dict(), "model.pth")
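Here is a CPU-only sketch of the workaround end to end, assuming PyTorch is installed. Apex is not used: a hypothetical patch_forward imitates the forward wrapper from the traceback, and make_net is a hypothetical toy model. It also covers loading, where the model must be rebuilt in code before load_state_dict is called, since the file no longer contains the module object itself.

```python
import io
import pickle

import torch
import torch.nn as nn


def make_net():
    # Hypothetical toy model standing in for `net`; any nn.Module behaves the same.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def patch_forward(old_fwd):
    # Hypothetical imitation of the patched forward: a local closure that
    # gets assigned to the instance, making the module itself unpicklable.
    def new_fwd(*args, **kwargs):
        return old_fwd(*args, **kwargs)

    return new_fwd


torch.manual_seed(0)
net = make_net()
net.forward = patch_forward(net.forward)

# Saving the whole module fails, as in the traceback above.
try:
    torch.save(net, io.BytesIO())
    whole_model_saved = True
except (AttributeError, pickle.PicklingError):
    whole_model_saved = False

# Saving only the state_dict (tensors keyed by parameter name) works.
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)

# To load, rebuild the same architecture, then copy the saved weights in.
restored = make_net()
restored.load_state_dict(torch.load(buffer))
restored.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(net(x), restored(x))
print(whole_model_saved, same)
```

If you also want to resume mixed-precision training rather than only reuse the weights, the checkpointing section of the Apex README suggests saving amp.state_dict() alongside the model and optimizer state dicts, and calling amp.initialize before restoring them with load_state_dict and amp.load_state_dict.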

