torch模型保存格式及保存与加载方式
本帖最后由 Handsome_zhou 于 2022-5-8 22:14 编辑state_dict = {'iter': iter,
'encoder_state_dict': self.model.encoder.state_dict(),
'decoder_state_dict': self.model.decoder.state_dict(),
'reduce_state_dict': self.model.reduce_state_dict(),
'optimizer': self.optimizer.state_dict(),
'current_loss': running_avg_loss
}
model_save_path = os.path.join(self.model_dir, 'model_%d_%d.pth' % (iter, int(time.time()))
torch.save(state_dict, model_save_path)
#=============torch模型保存一般保存为.pth格式======================
模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先
创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过
model.load_state_dict(dict)将模型的参数更新。
另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。
参考:https://blog.csdn.net/qq_43219379/article/details/123675375
https://www.cnblogs.com/xiaodai0/p/10413711.html
state = torch.load(model_path, map_location=lambda storage, location: storage)
self.encoder.load_state_dict(state['encoder_state_dict'])
self.decoder.load_state_dict(state['decoder_state_dict'], strict = False)
self.reduce_state.load_state_dict(state['reduce_state_dict'])
页:
[1]