from tensorflow.keras import datasets, layers, optimizers, Sequential,callbacks,losses
import os
import numpy as np
# Never summarise arrays with '...': the weight dump at the end of this script
# writes str(v.numpy()) and should contain every value.
np.set_printoptions(threshold=np.inf)
# Keras downloads MNIST and caches it under ~/.keras/datasets/ using the given
# file name. NOTE(review): the cached content is presumably the standard npz
# archive despite the ".pkl" name — confirm before relying on the extension.
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data("mnist.pkl")
# Scale raw 0-255 pixel intensities into the [0, 1] range.
x_train = x_train / 255.0
x_test = x_test / 255.0
def _conv_block(filters, name):
    """Return [3x3 same-padded ReLU Conv2D, 2x2/stride-2 MaxPool, BatchNorm]."""
    return [
        layers.Conv2D(kernel_size=3, strides=1, filters=filters, padding='same',
                      activation='relu', name=name),
        layers.MaxPooling2D(pool_size=2, strides=2),
        layers.BatchNormalization(),
    ]


# CNN classifier: reshape the 28x28 input to a single-channel image, apply two
# conv/pool/batch-norm stages (16 then 36 filters), then a dense head ending in
# a 10-way softmax. Operands of '+' are evaluated left to right, so layers are
# constructed in the same order as a single flat list would create them.
model = Sequential(
    [layers.InputLayer(input_shape=(28, 28)), layers.Reshape((28, 28, 1))]
    + _conv_block(16, 'layer_conv1')
    + _conv_block(36, 'layer_conv2')
    + [
        layers.Flatten(),
        layers.Dense(units=128, activation='selu'),
        layers.BatchNormalization(),
        layers.Dense(units=10, activation='softmax'),
    ]
)
# Adam with a 1e-3 step size. Use `learning_rate`: the old `lr` alias is
# deprecated and rejected by recent Keras releases.
# from_logits=False because the final Dense layer already applies softmax;
# the "sparse" loss/metric take integer class labels rather than one-hot vectors.
model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)
# Weights-only checkpoint prefix. The TF checkpoint format writes a
# "<prefix>.index" file, so its presence means a previous run left weights.
checkpoint_save_path = "./checkpoint/mnist.ckpt"
has_previous_checkpoint = os.path.exists(checkpoint_save_path + '.index')
if has_previous_checkpoint:
    # Resume training from the previously saved weights.
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Overwrite the checkpoint only when the monitored validation metric improves
# (ModelCheckpoint monitors 'val_loss' by default).
cp_callback = callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# NOTE(review): the test split doubles as the validation set here, so the
# "best" checkpoint is selected on test data — consider a separate split.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)
model.summary()
# Dump every trainable variable as three text records: its name, its shape,
# and its full value array (not summarised, because of the earlier
# np.set_printoptions(threshold=np.inf) call).
# The context manager closes the file even if a write raises. The explicit
# encoding keeps the output independent of the platform's default locale.
with open('weights.txt', 'w', encoding='utf-8') as file:
    for v in model.trainable_variables:
        file.write('\n'.join((str(v.name), str(v.shape), str(v.numpy()))) + '\n')