|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
# --- Model, loss function and optimizer ---
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network, then move its parameters to the selected device.
model = DeeperAutoencoderWithMoreAttention()
model = model.to(device)

# Pixel-wise mean-squared error between reconstruction and clean target.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Multiply the learning rate by 0.1 every 3000 scheduler.step() calls.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3000, gamma=0.1)

num_epochs = 5000

# Best average PSNR seen so far; the best checkpoint is rewritten whenever it improves.
best_psnr = 0
best_model_save_path = './save_root/best_autoencoder.pth'
# Template for periodic checkpoints; '{}' is filled with an epoch index.
model_save_path = './save_root/epoch{}_model.pth'
if __name__ == '__main__':
    # Local import: only this script section needs it, to create output dirs.
    import os

    print(device)
    print("开始运行ing")

    # torch.save() and the image dump raise if their target directories are
    # missing, so create them once before training starts.
    os.makedirs(os.path.dirname(best_model_save_path), exist_ok=True)
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    os.makedirs("./save_image", exist_ok=True)

    for epoch in range(num_epochs):
        start_time = time()

        # ---------------- training ----------------
        model.train()
        running_loss = 0.0
        num_train_batches = 0
        # NOTE(review): zip() pairs the i-th clean batch with the i-th noisy
        # batch and stops at the shorter loader. This is only correct if both
        # loaders yield corresponding samples in the same order (e.g.
        # shuffle=False or a shared sampler) — confirm where they are built.
        for (clean_images, _), (noisy_images, _) in zip(clean_loader, noisy_loader):
            clean_images = clean_images.to(device)
            noisy_images = noisy_images.to(device)
            outputs = model(noisy_images)
            loss = criterion(outputs, clean_images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_train_batches += 1

        # StepLR counts step() calls. Stepping once per epoch makes
        # step_size=3000 mean "decay at epoch 3000". Stepping per batch would
        # shrink the LR by 10x every 3000 batches and stall training early.
        scheduler.step()

        # ---------------- evaluation ----------------
        model.eval()
        total_psnr = 0.0
        num_eval_batches = 0
        with torch.no_grad():
            # One comparison image per epoch (the path contains the epoch),
            # rendered in eval mode without building autograd graphs.
            save_combined_images(model, noisy_loader, clean_loader, output_path=f"./save_image/index{epoch}epoch.png")
            for (clean_images, _), (noisy_images, _) in zip(clean_loader, noisy_loader):
                clean_images = clean_images.to(device)
                noisy_images = noisy_images.to(device)
                outputs = model(noisy_images)
                # float() accepts a Python number or a 0-d tensor.
                total_psnr += float(compute_psnr(clean_images, outputs))
                num_eval_batches += 1

        # Divide by the number of batches actually visited. zip() may stop
        # before len(clean_loader); an empty epoch must not raise.
        avg_loss = running_loss / max(num_train_batches, 1)
        avg_psnr = total_psnr / max(num_eval_batches, 1)
        end_time = time()
        print(
            f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.4f}, Time:{end_time - start_time:.2f}")

        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), best_model_save_path)
            print(f"在第 {epoch + 1} 个epoch保存了最佳模型,PSNR为: {avg_psnr:.4f}")
        if (epoch + 1) % 10 == 0:
            # NOTE(review): the file name uses the 0-based epoch while the
            # message uses the 1-based one (epoch9_model.pth is the 10th epoch).
            torch.save(model.state_dict(), model_save_path.format(epoch))
            print(f"在第 {epoch + 1} 个epoch保存模型,PSNR为: {avg_psnr:.4f}")

    print("Training completed!")
+-------------------+
| |
| DeeperAuto- |
| encoderWith |
| MoreAttention |
| |
+---------+---------+
|
|
+-----------------------------+------------^--------------+
| | |
| | |
+----v----+ +------v------+ +-------v--------+
| | | | | |
| device | | criterion | | optimizer |
| | | | | |
+----^----+ +-------------+ +-------+--------+
| |
| |
+----v----------------------------------------------------------v------+
| Model Training Loop |
| |
| +---------------------------------------------------+ |
| | | |
| | for epoch in range(num_epochs): | |
| | start_time = time() | |
| | for (clean_images, _), (noisy_images, _) in | |
| | zip(clean_loader, noisy_loader): | |
| | clean_images = clean_images.to(device) | |
| | noisy_images = noisy_images.to(device) | |
| | output = model(noisy_images) | |
| | loss = criterion(output, clean_images) | |
| | optimizer.zero_grad() | |
| | loss.backward() | |
| | optimizer.step() | |
| | scheduler.step() | |
| | save_combined_images(model, noisy_loader, | |
| | clean_loader, output_path=f"./save_image/index{epoch}epoch.png") | |
| | | |
| +---------------------------------------------------+ |
| |
| +---------------------------------------------------+ |
| | | |
| | model.eval() | |
| | total_psnr = 0 | |
| | with torch.no_grad(): | |
| | for (clean_images, _), (noisy_images, _) in | |
| | zip(clean_loader, noisy_loader): | |
| | clean_images = clean_images.to(device) | |
| | noisy_images = noisy_images.to(device) | |
| | output = model(noisy_images) | |
| | psnr = compute_psnr(clean_images, output) | |
| | total_psnr += psnr | |
| | | |
| +---------------------------------------------------+ |
| |
| +---------------------------------------------------+ |
| | | |
| | avg_psnr = total_psnr / len(clean_loader) | |
| | end_time = time() | |
| | | |
| | print( | |
| | f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, | |
| | PSNR: {avg_psnr:.4f}, Time: {end_time-start_time:.2f}") | |
| | | |
| | if avg_psnr > best_psnr: | |
| | best_psnr = avg_psnr | |
| | torch.save(model.state_dict(), best_model_save_path) | |
| | print(f"Saved the best model at epoch {epoch+1}, PSNR: {avg_psnr:.4f}") | |
| | | |
| | if (epoch+1) % 10 == 0: | |
| | torch.save(model.state_dict(), model_save_path.format(epoch)) | |
| | print(f"Saved model at epoch {epoch+1}, PSNR: {avg_psnr:.4f}") | |
| | | |
| | model.train() | |
| | | |
| +---------------------------------------------------+ |
| |
+------------------------------------------------------------------------+
求最佳答案
|
|