| 
 | 
 
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册  
 
x
 
 本帖最后由 小诺爷 于 2022-6-3 12:11 编辑  
# Train a small CNN classifier on MNIST with tf.keras. Resume from a saved
# weights checkpoint when one exists and keep saving the best weights. Then
# dump every trainable variable (name, shape and full values) to weights.txt.
from tensorflow.keras import datasets, layers, optimizers, Sequential, callbacks, losses

import os

import numpy as np

# Print arrays in full (no "..." summarisation) so weights.txt receives every
# value rather than a truncated preview.
np.set_printoptions(threshold=np.inf)

# NOTE(review): per the Keras docs, load_data's `path` is only the local cache
# file name (under ~/.keras/datasets). The cached data is not a pickle despite
# the ".pkl" name. It is kept unchanged so an existing cache is reused.
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data("mnist.pkl")
# Scale pixel intensities by 1/255 (presumably uint8 0-255 -> [0, 1]).
x_train, x_test = x_train / 255.0, x_test / 255.0

model = Sequential([
    layers.InputLayer(input_shape=(28, 28)),
    # Add a trailing channel axis so Conv2D receives (28, 28, 1) inputs.
    layers.Reshape((28, 28, 1)),
    layers.Conv2D(kernel_size=3, strides=1, filters=16, padding='same', activation='relu', name='layer_conv1'),
    layers.MaxPooling2D(pool_size=2, strides=2),
    layers.BatchNormalization(),
    layers.Conv2D(kernel_size=3, strides=1, filters=36, padding='same', activation='relu', name='layer_conv2'),
    layers.MaxPooling2D(pool_size=2, strides=2),
    layers.BatchNormalization(),
    layers.Flatten(),
    layers.Dense(units=128, activation='selu'),
    layers.BatchNormalization(),
    # The output layer already applies softmax, hence from_logits=False below.
    layers.Dense(units=10, activation='softmax'),
])

# `lr` is a deprecated alias that newer Keras optimizers reject;
# `learning_rate` is the supported keyword.
model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

checkpoint_save_path = "./checkpoint/mnist.ckpt"
# A TensorFlow-format weights checkpoint writes "<prefix>.index" next to its
# data shards. If that file exists, a previous run left weights to resume from.
# NOTE(review): Keras 3 requires weight-only paths ending in ".weights.h5";
# this path assumes the TF2 (Keras 2) checkpoint format.
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Save weights only, and only when the monitored quantity improves. `monitor`
# is left at its Keras default (val_loss per the docs).
cp_callback = callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                        save_weights_only=True,
                                        save_best_only=True)
model.fit(x=x_train, y=y_train, batch_size=32, epochs=5,
          validation_data=(x_test, y_test), validation_freq=1,
          callbacks=[cp_callback])

model.summary()

# Write each trainable variable's name, shape and values. The `with` block
# closes the file even if a write raises, unlike a manual open()/close() pair.
with open('weights.txt', 'w', encoding='utf-8') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')
 
  复制代码 
 |   
 
 
 
 |