An error when implementing adversarial training + LSTM with Keras

Posted on 2020-8-6 09:15:40

This post was last edited by 猪仔很忙 on 2020-8-7 18:12.

This post was originally posted in the Newbie Paradise (新手乐园) board by mistake, so I'm reposting it here.

I'm trying to build an LSTM model that incorporates adversarial training. The code is as follows:
import tensorflow as tf
from keras import backend as K
from keras import optimizers
from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint

# RNN, RNN_SEPARATE_2, myloss and Call_back_0 are custom components defined
# elsewhere in the project (not shown here).


def gradient_operation(args):
    # args = [y_true, y_pred, v_final], as passed in by the Lambda layer below
    y_true = args[0]
    y_pred = args[1]
    v_final = args[2]
    pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
    pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
    loss = (-K.mean(0.75 * K.pow(1. - pt_1, 0) * K.log(pt_1))
            - K.mean((1 - 0.75) * K.pow(pt_0, 0) * K.log(1. - pt_0)))
    perturb = K.gradients(loss, v_final)
    # add the gradient-based perturbation back onto the hidden representation
    adv_v_final = K.gradients(loss, v_final) + v_final
    return adv_v_final


def build_train(datas, machineID, modelName):
    train_data, train_labels, val_data, val_labels, test_data, test_labels = datas

    num_samples = train_data.shape[0]
    time_step = train_data.shape[1]
    feature_dim = train_data.shape[2]

    trace_input = Input(shape=(time_step, feature_dim), name='trace_input')
    label_input = Input(shape=(1,), name='label_input')

    if modelName == 'FEMT_LSTM':
        v_final = RNN_SEPARATE_2(time_step, feature_dim)(trace_input)
    else:
        v_final = RNN(time_step, feature_dim)(trace_input)
    pred = Dense(1, activation='sigmoid', name='pred')(v_final)
    # wrap the gradient/noise computation in a Lambda layer
    v_final_adv = Lambda(gradient_operation)([label_input, pred, v_final])
    adv_pred = Dense(1, activation='sigmoid', name='adv_pred')(v_final_adv)
    model = Model(inputs=[trace_input, label_input], outputs=[pred, adv_pred])
    model.compile(optimizer=optimizers.Adam(lr=0.001, clipvalue=15),
                  loss={'pred': myloss(alpha=0.75, gamma=0), 'adv_pred': myloss(alpha=0.75, gamma=0)},
                  loss_weights={'pred': 1., 'adv_pred': 0.1})

    print('Train...')
    model_save_path = './lib/model_cp_rnn_{}'.format(machineID)

    call_backs = [Call_back_0(valid_data=[val_data, val_labels, test_data, test_labels],
                              model_save_path=model_save_path),
                  ReduceLROnPlateau(monitor='val_loss', factor=0.8, patience=4, mode='min'),
                  ModelCheckpoint(filepath=model_save_path, monitor='val_loss',
                                  save_best_only=True, save_weights_only=True, mode='min')]

    model.fit({'trace_input': train_data, 'label_input': train_labels},
              y=[train_labels, train_labels],
              batch_size=64,
              epochs=60,
              callbacks=call_backs,
              validation_data=[val_data, val_labels])

The objective function is: loss = -(1/N) Σ [y·log(y_1) + (1−y)·log(1−y_1)] + 0.01 · ( -(1/N) Σ [y·log(y_2) + (1−y)·log(1−y_2)] ), where y is the ground-truth label, corresponding to label_input / y_true in the code; y_1 is the prediction on the non-adversarial sample, corresponding to pred / y_pred; and y_2 is the prediction on the adversarial sample, corresponding to adv_pred.
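Since myloss is not shown in the post, here is a minimal sketch of what a loss factory consistent with this objective could look like. This is an assumption based on the alpha/gamma arguments passed to it and the pt_1/pt_0 terms in gradient_operation; with gamma=0 it reduces to alpha-weighted binary cross-entropy.

# Hypothetical sketch only: the actual myloss used in model.compile() is not
# shown in the post. This assumes a focal-loss-style factory whose gamma=0
# case reduces to alpha-weighted binary cross-entropy, mirroring the
# pt_1/pt_0 terms in gradient_operation.
import tensorflow as tf
from keras import backend as K

def myloss(alpha=0.75, gamma=0):
    def loss_fn(y_true, y_pred):
        # pt_1: predicted probability on positive samples, 1 elsewhere
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        # pt_0: predicted probability on negative samples, 0 elsewhere
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        return (-K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1))
                - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0)))
    return loss_fn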

Each training sample is time-series data in the form of a 10×7 matrix (call it x for now), and the model predicts the classification result y_2 = f(x). In the middle of the model, I want to create a noise term by taking a gradient (adv_v_final in the code above), add it to v_final, and then apply the activation to the sum to obtain the classification prediction.
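For comparison, gradient-based adversarial perturbations are often sign-normalized and scaled by a small epsilon before being added back onto the representation (FGSM-style). A minimal illustrative sketch, not the code from this post (the function name and epsilon value are assumptions):

# Illustrative FGSM-style perturbation, not the code from the post:
# take the sign of the gradient and scale it by a small epsilon before
# adding it back onto the hidden representation v_final.
from keras import backend as K

def fgsm_perturb(loss, v_final, epsilon=0.01):   # epsilon is an assumed value
    grad = K.gradients(loss, v_final)[0]          # K.gradients returns a list
    perturb = epsilon * K.sign(K.stop_gradient(grad))
    return v_final + perturb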

To avoid the error "'NoneType' object has no attribute '_inbound_nodes'", I put the noise computation inside a Lambda layer. But now I get the following error:
Traceback (most recent call last):
  File "D:\Anaconda3\envs\edogawaAi\lib\contextlib.py", line 130, in __exit__
    self.gen.throw(type, value, traceback)
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\tensorflow\python\framework\ops.py", line 5652, in get_controller
    yield g
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\keras\engine\base_layer.py", line 474, in __call__
    output_shape = self.compute_output_shape(input_shape)
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\keras\layers\core.py", line 649, in compute_output_shape
    x = self.call(xs)
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\keras\layers\core.py", line 687, in call
    return self.function(inputs, **arguments)
  File "D:\PycharmProjects\turningpoint-master\model\rnn_turning_point.py", line 35, in gradient_operation
    return K.gradients(loss, [v_final]) + v_final
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\tensorflow\python\ops\math_ops.py", line 909, in r_binary_op_wrapper
    x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\tensorflow\python\framework\ops.py", line 1087, in convert_to_tensor
    return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\tensorflow\python\framework\ops.py", line 1145, in convert_to_tensor_v2
    as_ref=False)
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\tensorflow\python\framework\ops.py", line 1224, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\tensorflow\python\framework\constant_op.py", line 305, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\tensorflow\python\framework\constant_op.py", line 246, in constant
    allow_broadcast=True)
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\tensorflow\python\framework\constant_op.py", line 284, in _constant_impl
    allow_broadcast=allow_broadcast))
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\tensorflow\python\framework\tensor_util.py", line 466, in make_tensor_proto
    _AssertCompatible(values, dtype)
  File "D:\Anaconda3\envs\edogawaAi\lib\site-packages\tensorflow\python\framework\tensor_util.py", line 371, in _AssertCompatible
    (dtype.name, repr(mismatch), type(mismatch).__name__))
TypeError: Expected float32, got None of type '_Message' instead.

I'm new to machine learning and Python. Besides this error, if there are also problems in the code's logic, please don't hesitate to point them out :)