鱼C论坛

 找回密码
 立即注册
查看: 3106|回复: 4

帮忙详细解释这段代码。

[复制链接]
发表于 2023-8-16 19:47:17 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
# 模型、损失和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeeperAutoencoderWithMoreAttention().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
best_psnr = 0
best_model_save_path = './save_root/best_autoencoder.pth'
model_save_path = './save_root/epoch{}_model.pth'
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3000, gamma=0.1)

num_epochs = 5000
if __name__ == '__main__':
    print(device)
    print("开始运行ing")

    for epoch in range(num_epochs):
        start_time = time()

        for (clean_images, _), (noisy_images, _) in zip(clean_loader, noisy_loader):
            clean_images = clean_images.to(device)
            noisy_images = noisy_images.to(device)

            outputs = model(noisy_images)
            loss = criterion(outputs, clean_images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()

        save_combined_images(model, noisy_loader, clean_loader, output_path=f"./save_image/index{epoch}epoch.png")

        # 计算平均PSNR
        model.eval()
        total_psnr = 0
        with torch.no_grad():
            for (clean_images, _), (noisy_images, _) in zip(clean_loader, noisy_loader):
                clean_images = clean_images.to(device)
                noisy_images = noisy_images.to(device)
                outputs = model(noisy_images)
                psnr = compute_psnr(clean_images, outputs)
                total_psnr += psnr

        avg_psnr = total_psnr / len(clean_loader)
        end_time = time()

        print(
            f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, PSNR: {avg_psnr:.4f}, Time:{end_time - start_time:.2f}")

        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), best_model_save_path)
            print(f"在第 {epoch + 1} 个epoch保存了最佳模型,PSNR为: {avg_psnr:.4f}")

        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), model_save_path.format(epoch))
            print(f"在第 {epoch + 1} 个epoch保存模型,PSNR为: {avg_psnr:.4f}")

        model.train()

    print("Training completed!")
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-10-21 03:17

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表