鱼C论坛

 找回密码
 立即注册
查看: 1397|回复: 0

tensorflow实现自编码器(Autoencoder)用softmax分类时无准确率

[复制链接]
发表于 2022-3-12 00:30:42 | 显示全部楼层 |阅读模式
20鱼币
环境为python3.7,tensorflow-gpu 2.1

数据集为iris数据集,训练过程采用无监督学习,测试过程采用有监督学习
神经网络为最基本的自编码器:包含输入层,编码隐藏层,解码隐藏层,softmax层用于分类

训练过程的loss一直是在下降,但是测试过程中的预测准确率却从100%降到0,这是什么原因呢?

跪谢解答的大神!!!
import tensorflow as tf
from sklearn import datasets
from matplotlib import pyplot as plt
import numpy as np

# 导入iris数据集
x_data = datasets.load_iris().data
y_data = datasets.load_iris().target

# 打乱数据集
np.random.seed(116)
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)

# 将打乱的数据集构造成训练集和测试集
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]

# 转换数据类型
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)

# 定义分批输入的训练集
train_db = tf.data.Dataset.from_tensor_slices(x_train, ).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

# 定义超参数
input_size = 4
encoder_hidden = 3
decoder_hidden = 3
output_size = input_size

# 生成神经网络的参数,4个输入特征故,输入层为4个输入节点;因为3分类,故输出层为3个神经元
# 用tf.Variable()标记参数可训练
# 用tf.random.truncated_normal,如果x的取值在区间(μ-2σ,μ+2σ)之外则重新进行选择。这样保证了生成的值都在均值附近。
# tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)
# shape: 一维的张量,也是输出的张量。 mean: 正态分布的均值。 stddev: 正态分布的标准差。 dtype: 输出的类型。
# seed: 一个整数,当设置之后,每次生成的随机数都一样。
# name: 操作的名字。
# 使用seed使每次生成的随机数相同(方便教学,使大家结果都一致,在现实使用时不写seed)
encoder_w1 = tf.Variable(tf.random.truncated_normal([input_size, encoder_hidden], stddev=0.1))
encoder_b1 = tf.Variable(tf.random.truncated_normal([encoder_hidden], stddev=0.1))

decoder_w1 = tf.Variable(tf.random.truncated_normal([encoder_hidden, output_size], stddev=0.1))
decoder_b1 = tf.Variable(tf.random.truncated_normal([output_size], stddev=0.1))

#dense_w1 = tf.Variable(tf.random.truncated_normal([output_size, output_size], stddev=0.1))
#dense_b1 = tf.Variable(tf.random.truncated_normal([output_size]), stddev=0.1)

lr = 0.09  # 学习率为0.1
train_loss_results = []  # 将每轮的loss记录在此列表中,为后续画loss曲线提供数据
test_loss_results = []
test_acc = []  # 将每轮的acc记录在此列表中,为后续画acc曲线提供数据
epoch = 101  # 循环500轮
loss_train_all = 0  # 每轮分4个step,loss_all记录四个step生成的4个loss的和
loss_test_all = 0

# 训练部分
for epoch in range(epoch):  # 数据集级别的循环,每个epoch循环一次数据集
    for step, (x_train) in enumerate(train_db):  # batch级别的循环,每个step循环一次batch
        with tf.GradientTape() as tape:  # wtih结构记录梯度信息
            encoder_hidden_train_output = tf.nn.sigmoid(tf.matmul(x_train, encoder_w1) + encoder_b1)
            decoder_hidden_train_output = tf.nn.sigmoid(tf.matmul(encoder_hidden_train_output, decoder_w1) + decoder_b1)

            # 使输出softmax_hidden_output符合概率分布(此操作后与独热码同量级,可相减求loss)
            #softmax_hidden_train_output = tf.nn.softmax(encoder_hidden_train_output)
            #y_train_prediction = softmax_hidden_train_output
            # 将标签转换成独热码格式,方便计算loss和accuracy
            #y_real = tf.one_hot(y_train, dtype=3)

            loss_train = tf.reduce_mean(tf.square(x_train - decoder_hidden_train_output))
            # 用loss.numpy()取出loss中的值
            # 将每个step计算出的loss累加,为后续求loss平均值提供数据,这样计算的loss更准确
            loss_train_all += loss_train.numpy()
        # 计算loss对各个参数的梯度
            grads = tape.gradient(loss_train, [encoder_w1, encoder_b1, decoder_w1, decoder_b1])

        # 梯度下降法
        # 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_grad
        # tf.assign_sub(ref, value, use_locking=None, name=None)
        # 变量 ref 减去 value值,即 ref = ref - value
        # ref:变量;value:值;use_locking,默认 False, 若为 True,则受锁保护;name,名称
        encoder_w1.assign_sub(lr * grads[0])  # 参数encoder_w1自更新
        encoder_b1.assign_sub(lr * grads[1])  # 参数encoder_b1自更新
        decoder_w1.assign_sub(lr * grads[2])  # 参数decoder_w1自更新
        decoder_b1.assign_sub(lr * grads[3])  # 参数decoder_b1自更新

    # 每个epoch打印loss信息
    print('Epoch:{}, loss_train:{}'.format(epoch, loss_train_all / 4))
    train_loss_results.append(loss_train_all / 4)  # 将8个step的loss求平均记录在train_loss_results中
    loss_train_all = 0  # loss_all归零,为记录下一个epoch的loss做准备

    # 测试部分
    # total_correct为预测对的样本数,total_number为测试的总样本数,将这两个变量初始化为0
    correct = 0
    total_correct, total_number = 0, 0
    for x_test, y_test in test_db:
        # 使用更新后的参数进行预测
        encoder_hidden_test_output = tf.nn.sigmoid(tf.matmul(x_test, encoder_w1) + encoder_b1)
        #decoder_hidden_test_output = tf.nn.sigmoid(tf.matmul(encoder_hidden_test_output, decoder_w1) + decoder_b1)
        softmax_hidden_test_output = tf.nn.softmax(encoder_hidden_test_output)

        # tf.argmax(array, axis)按行或列返回array中最大元素的索引值
        # axis=0表示跨行(经度,搜寻每一列的最大值的索引);axis=1表示跨列(纬度,搜寻每一行的最大值的索引)
        # axis不指定的话,所有元素参与运算
        y_test_prediction = tf.argmax(softmax_hidden_test_output, axis=1)  # 返回decoder_hidden_test_output最大值的索引,即预测的分类

        # 将y_test_prediction转换成y_test的数据类型
        y_test_prediction = tf.cast(y_test_prediction, dtype=y_test.dtype)

        # 若分类正确,则correct=1,否则为0,将bool型的结果转换成int型
        correct = tf.cast(tf.equal(y_test, y_test_prediction), dtype=tf.int32)

        # 将每个batch的correct数加起来
        correct = tf.reduce_sum(correct)
        # 将所有batch中的correct数加起来
        total_correct += int(correct)
        # total_number为测试的总样本数,即x_test的行数,shape[0]返回变量的行数
        total_number += x_test.shape[0]
    # 总的准确率等于total_correct/total_number
    acc = total_correct / total_number
    test_acc.append(acc)

    print("Test_acc:{0}, correct:{1}, total_correct:{2}, total_number:{3}".format(acc, correct, total_correct, total_number))
    print('--------------------------')

# 绘制loss_train曲线
plt.title("Loss_Train Function Curve")  # 图标题
plt.xlabel("Epoch")  # x轴名称
plt.ylabel("Loss_Train")  # y轴名称
plt.plot(train_loss_results, label='$Loss_train$')  # 逐点画出train_loss_results值并连线,连线图标是Loss_train
plt.legend()  # 画出曲线图标
plt.show()  # 画出图像

# 绘制Accuracy曲线
plt.title("Accuracy Curve")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.plot(test_acc, label='$Accuracy$')
plt.legend()
plt.show()

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-1-12 01:44

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表