鱼C论坛

 找回密码
 立即注册
查看: 2728|回复: 0

[学习笔记] torch模型保存格式及保存与加载方式

[复制链接]
发表于 2022-5-8 22:08:52 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 Handsome_zhou 于 2022-5-8 22:14 编辑
state_dict = {'iter': iter,
                     'encoder_state_dict': self.model.encoder.state_dict(),
                    'decoder_state_dict': self.model.decoder.state_dict(),
                    'reduce_state_dict': self.model.reduce_state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'current_loss': running_avg_loss
                   }

model_save_path = os.path.join(self.model_dir, 'model_%d_%d.pth' % (iter, int(time.time()))
torch.save(state_dict, model_save_path)
#=============torch模型保存一般保存为.pth格式======================

模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先
创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过
model.load_state_dict(dict)将模型的参数更新。

另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。
参考:https://blog.csdn.net/qq_43219379/article/details/123675375
                https://www.cnblogs.com/xiaodai0/p/10413711.html
state = torch.load(model_path, map_location=lambda storage, location: storage)
self.encoder.load_state_dict(state['encoder_state_dict'])
self.decoder.load_state_dict(state['decoder_state_dict'], strict = False)
self.reduce_state.load_state_dict(state['reduce_state_dict'])
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-12-27 10:35

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表