|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
# Model, loss and optimizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeeperAutoencoderWithMoreAttention().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Best evaluation PSNR seen so far. Start at -inf rather than 0 so that the
# first evaluated epoch always yields a "best" checkpoint, even when its PSNR
# happens to be <= 0 (e.g. a poorly initialised model early in training).
best_psnr = float("-inf")
best_model_save_path = './save_root/best_autoencoder.pth'
# Periodic checkpoint path template, filled in via str.format(epoch).
model_save_path = './save_root/epoch{}_model.pth'
# StepLR multiplies the learning rate by gamma=0.1 once every 3000 calls to
# scheduler.step(); the unit of step_size is whatever granularity the
# training loop calls scheduler.step() at.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3000, gamma=0.1)
num_epochs = 5000
def _train_one_epoch(model, optimizer, criterion, clean_loader, noisy_loader, device):
    """Run one optimisation pass over paired (clean, noisy) batches.

    Batches are paired positionally with ``zip``, so iteration stops at the
    shorter of the two loaders. Each loader item is unpacked as
    ``(images, _)``; the second element is ignored.

    Args:
        model: Module mapping noisy images to reconstructions.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss, called as ``criterion(outputs, clean_images)``.
        clean_loader: Iterable of ``(clean_images, _)`` batches (targets).
        noisy_loader: Iterable of ``(noisy_images, _)`` batches (inputs).
        device: Device every batch is moved to before the forward pass.

    Returns:
        The mean of the per-batch loss values for this pass, or ``nan`` if
        the loaders produced no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    # NOTE(review): pairing by position is only correct if both loaders yield
    # aligned samples in the same order (e.g. shuffle=False) — confirm.
    for (clean_images, _), (noisy_images, _) in zip(clean_loader, noisy_loader):
        clean_images = clean_images.to(device)
        noisy_images = noisy_images.to(device)
        outputs = model(noisy_images)
        loss = criterion(outputs, clean_images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float("nan")


def _evaluate_psnr(model, clean_loader, noisy_loader, device, psnr_fn):
    """Return the mean PSNR of ``model`` over paired (clean, noisy) batches.

    Runs under ``torch.no_grad()`` with the model in eval mode. The model's
    previous train/eval mode is restored before returning, even on error.

    Args:
        model: Module mapping noisy images to reconstructions.
        clean_loader: Iterable of ``(clean_images, _)`` batches (references).
        noisy_loader: Iterable of ``(noisy_images, _)`` batches (inputs).
        device: Device every batch is moved to before the forward pass.
        psnr_fn: Called as ``psnr_fn(clean_images, outputs)``; its per-batch
            results are summed and averaged.

    Returns:
        The mean per-batch PSNR, or ``nan`` if the loaders produced no batches.
    """
    was_training = model.training
    model.eval()
    total_psnr = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for (clean_images, _), (noisy_images, _) in zip(clean_loader, noisy_loader):
                clean_images = clean_images.to(device)
                noisy_images = noisy_images.to(device)
                outputs = model(noisy_images)
                total_psnr += psnr_fn(clean_images, outputs)
                num_batches += 1
    finally:
        model.train(was_training)
    # Divide by the number of batches actually evaluated: zip() stops at the
    # shorter loader, so len(clean_loader) could over-count.
    return total_psnr / num_batches if num_batches else float("nan")


if __name__ == '__main__':
    import os

    # torch.save() fails when the parent directory is missing, so create all
    # output directories before training starts.
    for _out_dir in (os.path.dirname(best_model_save_path),
                     os.path.dirname(model_save_path),
                     "./save_image"):
        os.makedirs(_out_dir or ".", exist_ok=True)

    print(device)
    print("开始运行ing")
    for epoch in range(num_epochs):
        start_time = time()
        epoch_loss = _train_one_epoch(model, optimizer, criterion, clean_loader, noisy_loader, device)
        # One scheduler step per epoch, so StepLR's step_size counts epochs.
        scheduler.step()
        # NOTE(review): the model is still in train mode here (dropout/BN
        # active) unless save_combined_images switches modes itself — confirm.
        save_combined_images(model, noisy_loader, clean_loader, output_path=f"./save_image/index{epoch}epoch.png")
        # Mean PSNR over one evaluation pass.
        avg_psnr = _evaluate_psnr(model, clean_loader, noisy_loader, device, compute_psnr)
        end_time = time()
        print(
            f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, PSNR: {avg_psnr:.4f}, Time:{end_time - start_time:.2f}")
        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), best_model_save_path)
            print(f"在第 {epoch + 1} 个epoch保存了最佳模型,PSNR为: {avg_psnr:.4f}")
        # Periodic checkpoint every 10 epochs. The file name uses the 0-based
        # epoch index while the log message reports the 1-based epoch number.
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), model_save_path.format(epoch))
            print(f"在第 {epoch + 1} 个epoch保存模型,PSNR为: {avg_psnr:.4f}")
    print("Training completed!")
|