鱼C论坛

 找回密码
 立即注册
查看: 2388|回复: 5

[已解决]无法pickle tensorflow 弱引用对象

[复制链接]
发表于 2023-3-31 21:08:55 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 125404629 于 2023-3-31 21:16 编辑

求助大神!,如能解决请加qq806559062,必有重谢!!!
下列代码是我所写的程序的一部分,功能是通过遗传算法优化CNN模型,实现光谱数据分类,其中,save_network(self)中的save_network函数的功能是同时保存network类的拓朴结构,所建立的模型和适应度,通过pickle.dump将其序列化为.obj格式的序列化文件,通过pickle.load将其反序列化用于读取network.但是现在存在的问题是,报错:TypeError:can't pickle weakref objects,无法pickle 弱引用对象,因此无法将network进行序列化。现在,请问如何修改代码实现序列化network类,所建立的模型和适应度到一个二进制文件之中,并能够再次反序列化读取。
环境配置:python 3.7.8 ,tensorflow -gpu 2.6.0 ,keras 2.6.0,之前使用tensorflow cpu进行训练时未出现这个问题。
下面是部分代码:
  1. “def save_network(network):
  2.     object_file = open(network.name + '.obj', 'wb')
  3.     pickle.dump(network, object_file)
  4. #将模型文件以二进制格式进行读取
  5. def load_network(name):
  6.     object_file = open(name + '.obj', 'rb')
  7.     return pickle.load(object_file)

  8. “class Network:
  9.     __slots__ = ('name', 'block_list', 'fitness', 'model')
  10.     def __init__(self, it):
  11.         self.name = 'parent_' + str(it) if it == 0 else 'net_' + str(it)
  12.         self.block_list = []
  13.         self.fitness = None
  14.         self.model = None
  15.     def build_model(self):
  16.         model = keras.Sequential()                                # create model
  17.         for block in self.block_list:
  18.             for layer in block.get_layers():                # build model
  19.                 try:
  20.                     layer.build_layer(model)
  21.                 except:
  22.                     print("\nINDIVIDUAL ABORTED, CREATING A NEW ONE\n")
  23.                     return -1
  24.         return model

  25.     def train_and_evaluate(self, model, dataset):
  26.         print("Training", self.name)
  27.         model.compile(optimizer=model.opt, loss='categorical_crossentropy', metrics=['acc'])
  28.         #需要增加优化器学习率超参数
  29.         
  30.         history = model.fit(dataset['x_train'],
  31.                             dataset['y_train'],
  32.                             batch_size=model.batch_size,
  33.                             epochs=model.epochs,
  34.                             validation_data=(dataset['x_test'], dataset['y_test']),
  35.                             shuffle=True)
  36.         self.model = model                                    # model
  37.         self.fitness = history.history['val_acc'][-1]        # fitness
  38.         model.save(self.name + '.h5')                       # save model训练数据和训练后的参数保存为h5格式文件
  39.         save_network(self)
复制代码
最佳答案
2023-3-31 21:37:17
执行这个看看结果:

  1. print([eval('network.{}'.format(i)) for i in n.__slots__])
复制代码
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2023-3-31 21:11:26 | 显示全部楼层
先把你的代码复制到记事本,然后再复制粘贴出来,而且点击 <> 用代码块

直接从pychram复制一堆格式问题
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2023-3-31 21:17:00 | 显示全部楼层
isdkz 发表于 2023-3-31 21:11
先把你的代码复制到记事本,然后再复制粘贴出来,而且点击  用代码块

直接从pychram复制一堆格式问题

好的
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2023-3-31 21:37:17 | 显示全部楼层    本楼为最佳答案   
执行这个看看结果:

  1. print([eval('network.{}'.format(i)) for i in n.__slots__])
复制代码
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2023-3-31 21:52:54 | 显示全部楼层
本帖最后由 125404629 于 2023-3-31 21:54 编辑

好的
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2023-3-31 23:14:09 | 显示全部楼层
通过重写 __getstate__ 和 __setstate__ 方法来自定义 pickle 的序列化和反序列化的过程
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 1 反对 0

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-6-28 16:35

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表