Handsome_zhou 发表于 2022-5-8 22:08:52

torch模型保存格式及保存与加载方式

本帖最后由 Handsome_zhou 于 2022-5-8 22:14 编辑

state_dict = {'iter': iter,
                  'encoder_state_dict': self.model.encoder.state_dict(),
                  'decoder_state_dict': self.model.decoder.state_dict(),
                  'reduce_state_dict': self.model.reduce_state_dict(),
                  'optimizer': self.optimizer.state_dict(),
                  'current_loss': running_avg_loss
                   }

model_save_path = os.path.join(self.model_dir, 'model_%d_%d.pth' % (iter, int(time.time()))
torch.save(state_dict, model_save_path)

#=============torch模型保存一般保存为.pth格式======================

模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先
创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过
model.load_state_dict(dict)将模型的参数更新。

另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。
参考:https://blog.csdn.net/qq_43219379/article/details/123675375
                https://www.cnblogs.com/xiaodai0/p/10413711.html

state = torch.load(model_path, map_location=lambda storage, location: storage)
self.encoder.load_state_dict(state['encoder_state_dict'])
self.decoder.load_state_dict(state['decoder_state_dict'], strict = False)
self.reduce_state.load_state_dict(state['reduce_state_dict'])
页: [1]
查看完整版本: torch模型保存格式及保存与加载方式