Python实现FM【tensorflow2.0】
本帖最后由 糖逗 于 2021-1-7 11:56 编辑
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
def preprocess(x, y):
    """Cast one (features, label) element for the tf.data pipeline.

    Features are cast to float32, the default Keras float dtype, so the model
    does not have to re-cast float64 inputs (and, on some TF 2.x versions,
    warn about it) on every batch. Labels are cast to int64.

    Args:
        x: Feature values of one dataset element.
        y: Label of the same element.

    Returns:
        The ``(x, y)`` pair with ``x`` as float32 and ``y`` as int64.
    """
    x = tf.cast(x, dtype=tf.float32)
    y = tf.cast(y, dtype=tf.int64)
    return x, y
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
# Split the data into training and test sets (stratified on the label).
data = load_breast_cancer()
x_train, x_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2,
    random_state=11, stratify=data.target)
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

# Standardize each feature with statistics from the TRAINING split only, so no
# test information leaks into preprocessing. The FM interaction term squares
# its inputs, so raw, differently-scaled columns can push the logits into
# sigmoid saturation.
feature_mean = x_train.mean(axis=0)
feature_std = x_train.std(axis=0)
feature_std[feature_std == 0] = 1.0  # guard constant columns against /0
x_train = (x_train - feature_mean) / feature_std
x_test = (x_test - feature_mean) / feature_std

train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# A buffer as large as the training set gives a uniform shuffle each epoch.
train_db = train_db.shuffle(len(x_train)).map(preprocess).batch(20)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(20)

# Each dataset element is a (features, labels) tuple; inspect the parts
# separately (a tuple has no `.shape`, and mixed dtypes cannot be reduced).
sample = next(iter(train_db))
sample_x, sample_y = sample
print('sample:', sample_x.shape, sample_y.shape,
      tf.reduce_min(sample_x), tf.reduce_max(sample_x))
from tensorflow.keras import layers, optimizers
from tensorflow import keras
class FM(keras.Model):
    """Second-order Factorization Machine for binary classification.

    The model outputs ``sigmoid(w0 + <w, x> + sum_{i<j} <v_i, v_j> x_i x_j)``.
    The pairwise term is computed in O(k * n) with the identity

        sum_{i<j} <v_i, v_j> x_i x_j
            = 0.5 * sum_f [ (sum_i v_{i,f} x_i)^2 - sum_i v_{i,f}^2 x_i^2 ]

    Args:
        k: Number of latent factors per input feature (columns of ``v``).
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, k=4, **kwargs):
        super(FM, self).__init__(**kwargs)
        self.k = k
        # Linear part w0 + <w, x>. Created once here rather than per call, so
        # Keras tracks its weights and regularization losses a single time.
        self.fc = tf.keras.layers.Dense(
            units=1,
            bias_regularizer=tf.keras.regularizers.l2(0.01),
            kernel_regularizer=tf.keras.regularizers.l1(0.02))

    def build(self, input_shape):
        # One k-dimensional latent vector per input feature.
        self.v = self.add_weight(name='v',
                                 shape=(input_shape[-1], self.k),
                                 initializer='glorot_uniform',
                                 trainable=True)
        # Keep this as the last statement: the parent build may trace `call`,
        # which needs `self.v` to exist already.
        super(FM, self).build(input_shape)

    def call(self, x, training=True):
        """Return the predicted probability of the positive class.

        Assumes ``x`` is a rank-2 (batch, n_features) float tensor; the result
        then has shape (batch, 1). ``training`` is accepted for Keras
        compatibility but is not used.
        """
        linear = self.fc(x)
        # (sum_i v_{i,f} x_i)^2 for every factor f -> (batch, k)
        square_of_sum = tf.square(tf.matmul(x, self.v))
        # sum_i v_{i,f}^2 x_i^2 for every factor f -> (batch, k)
        sum_of_square = tf.matmul(tf.square(x), tf.square(self.v))
        interaction = 0.5 * tf.reduce_sum(square_of_sum - sum_of_square,
                                          axis=1, keepdims=True)
        return tf.sigmoid(linear + interaction)
# Build a standalone instance only to print the parameter summary. The input
# width is taken from the training data instead of being hard-coded.
model = FM()
model.build((None, x_train.shape[-1]))
model.summary()
def main():
    """Train a fresh FM on the module-level ``train_db`` and evaluate it.

    Validation during training and the final evaluation both use the
    module-level ``test_db``.
    """
    model = FM()
    model.compile(optimizer=keras.optimizers.Adam(1e-3),
                  loss=tf.keras.losses.binary_crossentropy,
                  # The model emits one sigmoid probability per sample, so use
                  # binary accuracy (threshold 0.5) explicitly.
                  metrics=[keras.metrics.BinaryAccuracy(name='accuracy')])
    model.fit(train_db, epochs=200, validation_data=test_db)
    model.evaluate(test_db)


if __name__ == '__main__':
    main()
注意事项:
1. `super(FM, self).build(input_shape)` 必须调用,且必须放在 build 方法的最后一行(父类的 build 可能会执行一次 call,此时 self.v 必须已经创建好)。
参考链接:https://blog.csdn.net/ganghaodream/article/details/98964903
论文:http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.393.8529&rep=rep1&type=pdf
代码由网络资源整合学习后写成,感谢网络中乐于进行学习分享的大家!
> 糖逗 发表于 2021-1-7 11:57
> 论文:http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.393.8529&rep=rep1&type=pdf
> 代码由 ...

感谢感谢

页:[1]