|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
本帖最后由 糖逗 于 2021-1-7 11:56 编辑
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
def preprocess(x, y):
    """Cast one (features, label) dataset element to the dtypes used downstream.

    Intended for ``tf.data.Dataset.map``, applied before batching.

    Args:
        x: Feature values of a single element.
        y: Label of the same element.

    Returns:
        Tuple ``(features, label)``: features as a float64 tensor, label as an
        int64 tensor.
    """
    features = tf.cast(x, dtype=tf.float64)
    label = tf.cast(y, dtype=tf.int64)
    return features, label
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the scikit-learn breast-cancer data and split it 80/20 into training
# and test sets, stratified so both splits keep the same class ratio.
data = load_breast_cancer()
x_train, x_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2,
    random_state=11, stratify=data.target,
)

# Standardize every feature column. The FM squares its inputs (x^2 @ V^2), so
# raw columns on very different scales would dominate the interaction term
# and saturate the sigmoid output. The scaler is fitted on the training split
# only, so no test-set statistics leak into training.
scaler = StandardScaler().fit(x_train)
x_train = scaler.transform(x_train)
x_test = scaler.transform(x_test)
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

# Shuffle with a buffer covering the whole training split (a smaller buffer
# only shuffles partially), then cast dtypes and batch.
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(len(x_train)).map(preprocess).batch(20)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(20)

# Peek at one training batch to sanity-check shapes and the value range.
sample = next(iter(train_db))
print('sample:', sample[0].shape, sample[1].shape,
      tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))
from tensorflow.keras import layers, optimizers
from tensorflow import keras
class FM(keras.Model):
    """Factorization Machine producing a binary-classification probability.

    For each input row ``x`` (length ``dim``) the model computes

        y = sigmoid(w0 + sum_i w_i * x_i
                    + 0.5 * sum_f [(sum_i v_if * x_i)^2 - sum_i v_if^2 * x_i^2])

    The second line equals the sum over all feature pairs i < j of
    <v_i, v_j> * x_i * x_j, evaluated in O(dim * k) instead of O(dim^2 * k).

    Args:
        k: Length of the latent factor vector learned per input feature.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, k=4, **kwargs):
        super(FM, self).__init__(**kwargs)
        self.k = k

    def build(self, input_shape):
        """Create the linear layer and the latent factor matrix V.

        Args:
            input_shape: Shape of the input batch; its last entry is ``dim``.
        """
        # Linear term w0 + sum_i w_i * x_i: L2 penalty on the bias, L1 on the
        # per-feature weights.
        self.fc = tf.keras.layers.Dense(
            units=1,
            bias_regularizer=tf.keras.regularizers.l2(0.01),
            kernel_regularizer=tf.keras.regularizers.l1(0.02),
        )
        # Latent factor matrix V: one k-dimensional row per input feature.
        self.v = self.add_weight(
            name='v',
            shape=(input_shape[-1], self.k),
            initializer='glorot_uniform',
            trainable=True,
        )
        # Must stay last: in tf.keras 2.x, Model.build may trace call() on a
        # symbolic input, which needs self.fc and self.v to exist already.
        super(FM, self).build(input_shape)

    def call(self, x, training=None):
        """Return sigmoid(linear + pairwise-interaction) for a batch ``x``.

        Args:
            x: Input batch of shape [batch, dim].
            training: Unused; accepted for Keras API compatibility.

        Returns:
            Tensor of shape [batch, 1] with values in (0, 1).
        """
        # [batch, dim] @ [dim, k] -> [batch, k], then squared element-wise.
        square_of_sum = tf.pow(tf.matmul(x, self.v), 2)
        # [batch, dim] @ [dim, k] -> [batch, k] using squared inputs/factors.
        sum_of_square = tf.matmul(tf.pow(x, 2), tf.pow(self.v, 2))
        # Sum over the k factors -> [batch, 1].
        interaction = 0.5 * tf.reduce_sum(
            square_of_sum - sum_of_square, axis=1, keepdims=True)
        # [batch, dim] @ [dim, 1] + bias -> [batch, 1].
        linear = self.fc(x)
        # Plain tensor addition: building a Keras Add layer inside call()
        # would create a new layer object on every forward pass.
        return tf.sigmoid(linear + interaction)
# Build a throwaway instance only to print the architecture and parameter
# counts; main() trains its own fresh model. The input width comes from the
# training features instead of a hard-coded 30.
model = FM()
model.build((None, x_train.shape[-1]))
model.summary()
def main():
    """Train a fresh FM on ``train_db`` for 200 epochs, then score ``test_db``.

    Uses Adam (learning rate 1e-3), binary cross-entropy loss and binary
    accuracy. ``test_db`` also serves as the validation set during training.
    """
    fm_model = FM()
    adam = keras.optimizers.Adam(1e-3)
    fm_model.compile(
        optimizer=adam,
        loss=tf.keras.losses.binary_crossentropy,
        metrics=[tf.keras.metrics.binary_accuracy],
    )
    fm_model.fit(train_db, epochs=200, validation_data=test_db)
    fm_model.evaluate(test_db)


if __name__ == '__main__':
    main()
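注意事项:
1. 在 FM.build 中必须调用 “super(FM, self).build(input_shape)”,并且要放在 build 的最后一行,即创建完 self.fc 和 self.v 之后:在 tf.keras 2.x 中,Model.build 会用符号输入执行一次 call(),此时这些层和权重必须已经存在。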
|