|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
# Model, loss function and optimizer (module-level training configuration).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeeperAutoencoderWithMoreAttention().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Start below any achievable value so the first evaluated epoch is always
# recorded (and saved) as the best model, even if its PSNR is <= 0 dB.
best_psnr = float("-inf")
best_model_save_path = './save_root/best_autoencoder.pth'
# Template for periodic checkpoints; filled in with an epoch number via .format().
model_save_path = './save_root/epoch{}_model.pth'
# Multiplies the learning rate by gamma=0.1 once every step_size=3000 calls to
# scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3000, gamma=0.1)
num_epochs = 5000
if __name__ == '__main__':
    # Script entry point: train the denoising autoencoder on paired
    # (clean, noisy) batches, evaluate mean PSNR each epoch, and save the best
    # model plus a checkpoint every 10 epochs.
    import os  # only the script entry point needs it

    print(device)
    print("开始运行ing")

    # torch.save fails if the target directory does not exist yet.
    os.makedirs(os.path.dirname(best_model_save_path), exist_ok=True)
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    os.makedirs("./save_image", exist_ok=True)

    for epoch in range(num_epochs):
        start_time = time()
        model.train()

        # ---- Training pass ----
        # NOTE(review): zip() pairs the i-th clean batch with the i-th noisy
        # batch, so both loaders must yield corresponding images in the same
        # order (e.g. shuffle=False or a shared sampler/seed) -- confirm where
        # clean_loader / noisy_loader are built.
        running_loss = 0.0
        num_train_batches = 0
        for (clean_images, _), (noisy_images, _) in zip(clean_loader, noisy_loader):
            clean_images = clean_images.to(device)
            noisy_images = noisy_images.to(device)
            outputs = model(noisy_images)
            loss = criterion(outputs, clean_images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_train_batches += 1
        if num_train_batches == 0:
            raise RuntimeError("clean_loader/noisy_loader produced no batches")
        avg_loss = running_loss / num_train_batches
        # StepLR(step_size=3000) with num_epochs=5000 implies per-epoch decay,
        # so advance the scheduler once per epoch, not once per batch.
        scheduler.step()

        # ---- Evaluation pass: mean PSNR over the evaluated batches ----
        model.eval()
        total_psnr = 0.0
        num_eval_batches = 0
        with torch.no_grad():
            for (clean_images, _), (noisy_images, _) in zip(clean_loader, noisy_loader):
                clean_images = clean_images.to(device)
                noisy_images = noisy_images.to(device)
                outputs = model(noisy_images)
                # float() keeps the accumulator a plain number even if
                # compute_psnr returns a 0-dim tensor (assumption -- its return
                # type is defined elsewhere).
                total_psnr += float(compute_psnr(clean_images, outputs))
                num_eval_batches += 1
            # One preview image per epoch, rendered in eval mode without grads.
            save_combined_images(model, noisy_loader, clean_loader,
                                 output_path=f"./save_image/index{epoch + 1}epoch.png")
        # Divide by the batches actually evaluated: zip() stops at the shorter
        # loader, so len(clean_loader) can overcount.
        avg_psnr = total_psnr / max(num_eval_batches, 1)
        end_time = time()
        print(
            f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.4f}, Time:{end_time - start_time:.2f}")

        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), best_model_save_path)
            print(f"在第 {epoch + 1} 个epoch保存了最佳模型,PSNR为: {avg_psnr:.4f}")
        if (epoch + 1) % 10 == 0:
            # Name the file with the same 1-based epoch number that is printed.
            torch.save(model.state_dict(), model_save_path.format(epoch + 1))
            print(f"在第 {epoch + 1} 个epoch保存模型,PSNR为: {avg_psnr:.4f}")
    print("Training completed!")
- +-------------------+
- | |
- | DeeperAuto- |
- | encoderWith |
- | MoreAttention |
- | |
- +---------+---------+
- |
- |
- +-----------------------------+------------^--------------+
- | | |
- | | |
- +----v----+ +------v------+ +-------v--------+
- | | | | | |
- | device | | criterion | | optimizer |
- | | | | | |
- +----^----+ +-------------+ +-------+--------+
- | |
- | |
- +----v----------------------------------------------------------v------+
- | Model Training Loop |
- | |
- | +---------------------------------------------------+ |
- | | | |
- | | for epoch in range(num_epochs): | |
- | | start_time = time() | |
- | | for (clean_images, _), (noisy_images, _) in | |
- | | zip(clean_loader, noisy_loader): | |
- | | clean_images = clean_images.to(device) | |
- | | noisy_images = noisy_images.to(device) | |
- | | outputs = model(noisy_images) | |
- | | loss = criterion(outputs, clean_images) | |
- | | optimizer.zero_grad() | |
- | | loss.backward() | |
- | | optimizer.step() | |
- | | scheduler.step() | |
- | | save_combined_images(model, noisy_loader, | |
- | | clean_loader, output_path=f"./save_image/index{epoch}epoch.png") | |
- | | | |
- | +---------------------------------------------------+ |
- | |
- | +---------------------------------------------------+ |
- | | | |
- | | model.eval() | |
- | | total_psnr = 0 | |
- | | with torch.no_grad(): | |
- | | for (clean_images, _), (noisy_images, _) in | |
- | | zip(clean_loader, noisy_loader): | |
- | | clean_images = clean_images.to(device) | |
- | | noisy_images = noisy_images.to(device) | |
- | | outputs = model(noisy_images) | |
- | | psnr = compute_psnr(clean_images, outputs) | |
- | | total_psnr += psnr | |
- | | | |
- | +---------------------------------------------------+ |
- | |
- | +---------------------------------------------------+ |
- | | | |
- | | avg_psnr = total_psnr / len(clean_loader) | |
- | | end_time = time() | |
- | | | |
- | | print( | |
- | | f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, | |
- | | PSNR: {avg_psnr:.4f}, Time: {end_time-start_time:.2f}") | |
- | | | |
- | | if avg_psnr > best_psnr: | |
- | | best_psnr = avg_psnr | |
- | | torch.save(model.state_dict(), best_model_save_path) | |
- | | print(f"Saved the best model at epoch {epoch+1}, PSNR: {avg_psnr:.4f}") | |
- | | | |
- | | if (epoch+1) % 10 == 0: | |
- | | torch.save(model.state_dict(), model_save_path.format(epoch)) | |
- | | print(f"Saved model at epoch {epoch+1}, PSNR: {avg_psnr:.4f}") | |
- | | | |
- | | model.train() | |
- | | | |
- | +---------------------------------------------------+ |
- | |
- +------------------------------------------------------------------------+
复制代码
求最佳答案
|
|