125404629 发表于 2023-3-31 21:08:55

无法pickle tensorflow 弱引用对象

本帖最后由 125404629 于 2023-3-31 21:16 编辑

求助大神!,如能解决请加qq806559062,必有重谢!!!
下列代码是我所写的程序的一部分,功能是通过遗传算法优化CNN模型,实现光谱数据分类,其中,save_network(self)中的save_network函数的功能是同时保存network类的拓朴结构,所建立的模型和适应度,通过pickle.dump将其序列化为.obj格式的序列化文件,通过pickle.load将其反序列化用于读取network.但是现在存在的问题是,报错:TypeError:can't pickle weakref objects,无法pickle 弱引用对象,因此无法将network进行序列化。现在,请问如何修改代码实现序列化network类,所建立的模型和适应度到一个二进制文件之中,并能够再次反序列化读取。
环境配置:python 3.7.8 ,tensorflow -gpu 2.6.0 ,keras 2.6.0,之前使用tensorflow cpu进行训练时未出现这个问题。
下面是部分代码:
“def save_network(network):
    object_file = open(network.name + '.obj', 'wb')
    pickle.dump(network, object_file)
#将模型文件以二进制格式进行读取
def load_network(name):
    object_file = open(name + '.obj', 'rb')
    return pickle.load(object_file)

“class Network:
    __slots__ = ('name', 'block_list', 'fitness', 'model')
    def __init__(self, it):
      self.name = 'parent_' + str(it) if it == 0 else 'net_' + str(it)
      self.block_list = []
      self.fitness = None
      self.model = None
    def build_model(self):
      model = keras.Sequential()                              # create model
      for block in self.block_list:
            for layer in block.get_layers():                # build model
                try:
                  layer.build_layer(model)
                except:
                  print("\nINDIVIDUAL ABORTED, CREATING A NEW ONE\n")
                  return -1
      return model

    def train_and_evaluate(self, model, dataset):
      print("Training", self.name)
      model.compile(optimizer=model.opt, loss='categorical_crossentropy', metrics=['acc'])
      #需要增加优化器学习率超参数
      
      history = model.fit(dataset['x_train'],
                            dataset['y_train'],
                            batch_size=model.batch_size,
                            epochs=model.epochs,
                            validation_data=(dataset['x_test'], dataset['y_test']),
                            shuffle=True)
      self.model = model                                    # model
      self.fitness = history.history['val_acc'][-1]      # fitness
      model.save(self.name + '.h5')                     # save model训练数据和训练后的参数保存为h5格式文件
      save_network(self)

isdkz 发表于 2023-3-31 21:11:26

先把你的代码复制到记事本,然后再复制粘贴出来,而且点击 <> 用代码块

直接从pychram复制一堆格式问题

125404629 发表于 2023-3-31 21:17:00

isdkz 发表于 2023-3-31 21:11
先把你的代码复制到记事本,然后再复制粘贴出来,而且点击用代码块

直接从pychram复制一堆格式问题

好的

isdkz 发表于 2023-3-31 21:37:17

执行这个看看结果:

print()

125404629 发表于 2023-3-31 21:52:54

本帖最后由 125404629 于 2023-3-31 21:54 编辑

好的

isdkz 发表于 2023-3-31 23:14:09

通过重写 __getstate__ 和 __setstate__ 方法来自定义 pickle 的序列化和反序列化的过程
页: [1]
查看完整版本: 无法pickle tensorflow 弱引用对象