PyTorch实现断点继续训练

训练模型的保存包括两种:
1、保存整个模型框架以及模型参数(存储文件过大,不推荐)
torch.save(model,path)
2、仅仅保存模型的参数文件(推荐)
torch.save(model.state_dict(),path)
"state_dict"表示state dictionary,即字典类型的参数,模型本身的参数。
例如
torch.save(model.state_dict(),'{}/moilenetV2_{}_{}.pth'.format('./models',epoch,acc))


模型的断点继续训练
Resume = True
# Resume = False
if Resume:
	path_checkpoint = 'your/new/model/path.pth'
	checkpoint = torch.load(path_checkpoint, map_location = torch.device('cuda'))
	model.load_state_dict(checkpoint)

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注

10 + 18 =