鱼C论坛

 找回密码
 立即注册
查看: 1758|回复: 5

[已解决]无法pickle tensorflow 弱引用对象

[复制链接]
发表于 2023-3-31 21:08:55 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 125404629 于 2023-3-31 21:16 编辑

求助大神!,如能解决请加qq806559062,必有重谢!!!
下列代码是我所写的程序的一部分,功能是通过遗传算法优化CNN模型,实现光谱数据分类,其中,save_network(self)中的save_network函数的功能是同时保存network类的拓朴结构,所建立的模型和适应度,通过pickle.dump将其序列化为.obj格式的序列化文件,通过pickle.load将其反序列化用于读取network.但是现在存在的问题是,报错:TypeError:can't pickle weakref objects,无法pickle 弱引用对象,因此无法将network进行序列化。现在,请问如何修改代码实现序列化network类,所建立的模型和适应度到一个二进制文件之中,并能够再次反序列化读取。
环境配置:python 3.7.8 ,tensorflow -gpu 2.6.0 ,keras 2.6.0,之前使用tensorflow cpu进行训练时未出现这个问题。
下面是部分代码:
“def save_network(network):
    object_file = open(network.name + '.obj', 'wb')
    pickle.dump(network, object_file)
#将模型文件以二进制格式进行读取
def load_network(name):
    object_file = open(name + '.obj', 'rb')
    return pickle.load(object_file)
”
“class Network:
    __slots__ = ('name', 'block_list', 'fitness', 'model')
    def __init__(self, it):
        self.name = 'parent_' + str(it) if it == 0 else 'net_' + str(it)
        self.block_list = []
        self.fitness = None
        self.model = None
    def build_model(self):
        model = keras.Sequential()                                # create model
        for block in self.block_list:
            for layer in block.get_layers():                # build model
                try:
                    layer.build_layer(model)
                except:
                    print("\nINDIVIDUAL ABORTED, CREATING A NEW ONE\n")
                    return -1
        return model

    def train_and_evaluate(self, model, dataset):
        print("Training", self.name)
        model.compile(optimizer=model.opt, loss='categorical_crossentropy', metrics=['acc'])
        #需要增加优化器学习率超参数
        
        history = model.fit(dataset['x_train'],
                            dataset['y_train'],
                            batch_size=model.batch_size,
                            epochs=model.epochs,
                            validation_data=(dataset['x_test'], dataset['y_test']),
                            shuffle=True)
        self.model = model                                    # model
        self.fitness = history.history['val_acc'][-1]        # fitness
        model.save(self.name + '.h5')                       # save model训练数据和训练后的参数保存为h5格式文件
        save_network(self)
”
最佳答案
2023-3-31 21:37:17
执行这个看看结果:
print([eval('network.{}'.format(i)) for i in n.__slots__])
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2023-3-31 21:11:26 | 显示全部楼层
先把你的代码复制到记事本,然后再复制粘贴出来,而且点击 <> 用代码块

直接从pychram复制一堆格式问题
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2023-3-31 21:17:00 | 显示全部楼层
isdkz 发表于 2023-3-31 21:11
先把你的代码复制到记事本,然后再复制粘贴出来,而且点击  用代码块

直接从pychram复制一堆格式问题

好的
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2023-3-31 21:37:17 | 显示全部楼层    本楼为最佳答案   
执行这个看看结果:
print([eval('network.{}'.format(i)) for i in n.__slots__])
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2023-3-31 21:52:54 | 显示全部楼层
本帖最后由 125404629 于 2023-3-31 21:54 编辑

好的
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2023-3-31 23:14:09 | 显示全部楼层
通过重写 __getstate__ 和 __setstate__ 方法来自定义 pickle 的序列化和反序列化的过程
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 1 反对 0

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-11-14 21:04

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表