鱼C论坛

 找回密码
 立即注册
查看: 3082|回复: 0

[学习笔记] torch模型保存格式及保存与加载方式

[复制链接]
发表于 2022-5-8 22:08:52 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 Handsome_zhou 于 2022-5-8 22:14 编辑
  1. state_dict = {'iter': iter,
  2.                     'encoder_state_dict': self.model.encoder.state_dict(),
  3.                     'decoder_state_dict': self.model.decoder.state_dict(),
  4.                     'reduce_state_dict': self.model.reduce_state_dict(),
  5.                     'optimizer': self.optimizer.state_dict(),
  6.                     'current_loss': running_avg_loss
  7.                    }

  8. model_save_path = os.path.join(self.model_dir, 'model_%d_%d.pth' % (iter, int(time.time()))
  9. torch.save(state_dict, model_save_path)
复制代码

#=============torch模型保存一般保存为.pth格式======================

模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先
创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过
model.load_state_dict(dict)将模型的参数更新。

另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。
参考:https://blog.csdn.net/qq_43219379/article/details/123675375
                https://www.cnblogs.com/xiaodai0/p/10413711.html

  1. state = torch.load(model_path, map_location=lambda storage, location: storage)
  2. self.encoder.load_state_dict(state['encoder_state_dict'])
  3. self.decoder.load_state_dict(state['decoder_state_dict'], strict = False)
  4. self.reduce_state.load_state_dict(state['reduce_state_dict'])
复制代码
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-24 09:03

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表