鱼C论坛

 找回密码
 立即注册
查看: 1367|回复: 3

[已解决]请画出这段代码的结构图

[复制链接]
发表于 2023-8-16 21:36:14 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
# 模型、损失和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeeperAutoencoderWithMoreAttention().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
best_psnr = 0
best_model_save_path = './save_root/best_autoencoder.pth'
model_save_path = './save_root/epoch{}_model.pth'
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3000, gamma=0.1)

num_epochs = 5000
if __name__ == '__main__':
    print(device)
    print("开始运行ing")

    for epoch in range(num_epochs):
        start_time = time()

        for (clean_images, _), (noisy_images, _) in zip(clean_loader, noisy_loader):
            clean_images = clean_images.to(device)
            noisy_images = noisy_images.to(device)

            outputs = model(noisy_images)
            loss = criterion(outputs, clean_images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()

        save_combined_images(model, noisy_loader, clean_loader, output_path=f"./save_image/index{epoch}epoch.png")

        # 计算平均PSNR
        model.eval()
        total_psnr = 0
        with torch.no_grad():
            for (clean_images, _), (noisy_images, _) in zip(clean_loader, noisy_loader):
                clean_images = clean_images.to(device)
                noisy_images = noisy_images.to(device)
                outputs = model(noisy_images)
                psnr = compute_psnr(clean_images, outputs)
                total_psnr += psnr

        avg_psnr = total_psnr / len(clean_loader)
        end_time = time()

        print(
            f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, PSNR: {avg_psnr:.4f}, Time:{end_time - start_time:.2f}")

        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            torch.save(model.state_dict(), best_model_save_path)
            print(f"在第 {epoch + 1} 个epoch保存了最佳模型,PSNR为: {avg_psnr:.4f}")

        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), model_save_path.format(epoch))
            print(f"在第 {epoch + 1} 个epoch保存模型,PSNR为: {avg_psnr:.4f}")

        model.train()

    print("Training completed!")
最佳答案
2023-8-16 21:54:43
  1.                                       +-------------------+
  2.                                       |                   |
  3.                                       |    DeeperAuto-    |
  4.                                       |    encoderWith    |
  5.                                       |   MoreAttention  |
  6.                                       |                   |
  7.                                       +---------+---------+
  8.                                                 |
  9.                                                 |
  10.      +-----------------------------+------------^--------------+
  11.      |                             |                           |
  12.      |                             |                           |
  13. +----v----+                 +------v------+            +-------v--------+
  14. |         |                 |             |            |                |
  15. |  device |                 |  criterion  |            |   optimizer    |
  16. |         |                 |             |            |                |
  17. +----^----+                 +-------------+            +-------+--------+
  18.      |                                                          |
  19.      |                                                          |
  20. +----v----------------------------------------------------------v------+
  21. |                     Model Training Loop                                |
  22. |                                                                        |
  23. |    +---------------------------------------------------+               |
  24. |    |                                                   |               |
  25. |    |   for epoch in range(num_epochs):                  |               |
  26. |    |       start_time = time()                          |               |
  27. |    |       for (clean_images, _), (noisy_images, _) in    |               |
  28. |    |           zip(clean_loader, noisy_loader):          |               |
  29. |    |           clean_images = clean_images.to(device)    |               |
  30. |    |           noisy_images = noisy_images.to(device)    |               |
  31. |    |           output = model(noisy_images)              |               |
  32. |    |           loss = criterion(output, clean_images)    |               |
  33. |    |           optimizer.zero_grad()                     |               |
  34. |    |           loss.backward()                           |               |
  35. |    |           optimizer.step()                          |               |
  36. |    |           scheduler.step()                          |               |
  37. |    |           save_combined_images(model, noisy_loader,  |               |
  38. |    |               clean_loader, output_path=f"./save_image/index{epoch}epoch.png")  |               |
  39. |    |                                                   |               |
  40. |    +---------------------------------------------------+               |
  41. |                                                                        |
  42. |    +---------------------------------------------------+               |
  43. |    |                                                   |               |
  44. |    |   model.eval()                                    |               |
  45. |    |   total_psnr = 0                                 |               |
  46. |    |   with torch.no_grad():                           |               |
  47. |    |       for (clean_images, _), (noisy_images, _) in  |               |
  48. |    |           zip(clean_loader, noisy_loader):          |               |
  49. |    |           clean_images = clean_images.to(device)    |               |
  50. |    |           noisy_images = noisy_images.to(device)    |               |
  51. |    |           output = model(noisy_images)              |               |
  52. |    |           psnr = compute_psnr(clean_images, output) |               |
  53. |    |           total_psnr += psnr                       |               |
  54. |    |                                                   |               |
  55. |    +---------------------------------------------------+               |
  56. |                                                                        |
  57. |    +---------------------------------------------------+               |
  58. |    |                                                   |               |
  59. |    |   avg_psnr = total_psnr / len(clean_loader)        |               |
  60. |    |   end_time = time()                               |               |
  61. |    |                                                   |               |
  62. |    |   print(                                          |               |
  63. |    |       f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f},     |               |
  64. |    |       PSNR: {avg_psnr:.4f}, Time: {end_time-start_time:.2f}")  |               |
  65. |    |                                                   |               |
  66. |    |   if avg_psnr > best_psnr:                         |               |
  67. |    |       best_psnr = avg_psnr                         |               |
  68. |    |       torch.save(model.state_dict(), best_model_save_path)          |               |
  69. |    |       print(f"Saved the best model at epoch {epoch+1}, PSNR: {avg_psnr:.4f}")  |               |
  70. |    |                                                   |               |
  71. |    |   if (epoch+1) % 10 == 0:                          |               |
  72. |    |       torch.save(model.state_dict(), model_save_path.format(epoch))  |               |
  73. |    |       print(f"Saved model at epoch {epoch+1}, PSNR: {avg_psnr:.4f}")  |               |
  74. |    |                                                   |               |
  75. |    |   model.train()                                   |               |
  76. |    |                                                   |               |
  77. |    +---------------------------------------------------+               |
  78. |                                                                        |
  79. +------------------------------------------------------------------------+
复制代码

求最佳答案
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2023-8-16 21:36:38 | 显示全部楼层
抱歉,作为一个文本AI助手,我无法直接绘制图形。但是,我可以为您提供代码的层次结构说明。
  1. - 模型、损失和优化器的定义
  2. - 设备选择
  3. - 主循环开始
  4.   - 对每个 epoch 进行迭代
  5.     - 记录开始时间
  6.     - 对于每个数据批次
  7.       - 将干净图像和噪声图像放入设备中
  8.       - 运行模型得到输出
  9.       - 计算损失
  10.       - 清除梯度
  11.       - 反向传播并更新参数
  12.     - 更新学习率
  13.     - 保存组合图像
  14.     - 计算平均PSNR
  15.       - 对于每个数据批次
  16.         - 将干净图像和噪声图像放入设备中
  17.         - 运行模型得到输出
  18.         - 计算PSNR
  19.         - 累计PSNR值
  20.     - 计算平均PSNR
  21.     - 记录结束时间
  22.     - 打印当前epoch的损失、PSNR和运行时间
  23.     - 如果当前PSNR超过最佳PSNR,则保存模型
  24.     - 每10个epoch保存一次模型
  25.     - 将模型设置为训练模式
  26. - 输出训练完成的提示信息
复制代码

希望这能帮到您!如果您还有其他问题,请随时提问。

如果回答对你有帮助,请给我一个最佳答案!
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2023-8-16 21:38:51 | 显示全部楼层
这段代码是一个训练深度自编码器的程序。下面是代码的结构图:

  1.   ┌───────────────────────┐
  2.   │                       │
  3.   │    模型、损失和优化器   │
  4.   │                       │
  5.   └───────────────────────┘
  6.                 │
  7.                 ▼
  8.   ┌───────────────────────┐
  9.   │                       │
  10.   │       数据加载        │
  11.   │                       │
  12.   └───────────────────────┘
  13.                 │
  14.                 ▼
  15.   ┌───────────────────────┐
  16.   │                       │
  17.   │      模型训练循环      │
  18.   │                       │
  19.   └───────────────────────┘
  20.                 │
  21.                 ▼
  22.   ┌───────────────────────┐
  23.   │                       │
  24.   │   保存最佳模型和中间模型  │
  25.   │                       │
  26.   └───────────────────────┘
  27.                 │
  28.                 ▼
  29.   ┌───────────────────────┐
  30.   │                       │
  31.   │    训练完成,输出信息   │
  32.   │                       │
  33.   └───────────────────────┘
复制代码


在代码的主要部分,有一个循环来进行模型的训练。每个epoch内部,代码会遍历clean_loader和noisy_loader,将数据加载到设备上,然后通过模型生成输出,计算损失并进行优化器的反向传播。然后,代码会进行学习率的调整,并保存生成的合成图像。接下来,代码会计算平均PSNR,并保存最佳模型和中间模型。最后,代码会输出每个epoch的训练信息。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2023-8-16 21:54:43 | 显示全部楼层    本楼为最佳答案   
  1.                                       +-------------------+
  2.                                       |                   |
  3.                                       |    DeeperAuto-    |
  4.                                       |    encoderWith    |
  5.                                       |   MoreAttention  |
  6.                                       |                   |
  7.                                       +---------+---------+
  8.                                                 |
  9.                                                 |
  10.      +-----------------------------+------------^--------------+
  11.      |                             |                           |
  12.      |                             |                           |
  13. +----v----+                 +------v------+            +-------v--------+
  14. |         |                 |             |            |                |
  15. |  device |                 |  criterion  |            |   optimizer    |
  16. |         |                 |             |            |                |
  17. +----^----+                 +-------------+            +-------+--------+
  18.      |                                                          |
  19.      |                                                          |
  20. +----v----------------------------------------------------------v------+
  21. |                     Model Training Loop                                |
  22. |                                                                        |
  23. |    +---------------------------------------------------+               |
  24. |    |                                                   |               |
  25. |    |   for epoch in range(num_epochs):                  |               |
  26. |    |       start_time = time()                          |               |
  27. |    |       for (clean_images, _), (noisy_images, _) in    |               |
  28. |    |           zip(clean_loader, noisy_loader):          |               |
  29. |    |           clean_images = clean_images.to(device)    |               |
  30. |    |           noisy_images = noisy_images.to(device)    |               |
  31. |    |           output = model(noisy_images)              |               |
  32. |    |           loss = criterion(output, clean_images)    |               |
  33. |    |           optimizer.zero_grad()                     |               |
  34. |    |           loss.backward()                           |               |
  35. |    |           optimizer.step()                          |               |
  36. |    |           scheduler.step()                          |               |
  37. |    |           save_combined_images(model, noisy_loader,  |               |
  38. |    |               clean_loader, output_path=f"./save_image/index{epoch}epoch.png")  |               |
  39. |    |                                                   |               |
  40. |    +---------------------------------------------------+               |
  41. |                                                                        |
  42. |    +---------------------------------------------------+               |
  43. |    |                                                   |               |
  44. |    |   model.eval()                                    |               |
  45. |    |   total_psnr = 0                                 |               |
  46. |    |   with torch.no_grad():                           |               |
  47. |    |       for (clean_images, _), (noisy_images, _) in  |               |
  48. |    |           zip(clean_loader, noisy_loader):          |               |
  49. |    |           clean_images = clean_images.to(device)    |               |
  50. |    |           noisy_images = noisy_images.to(device)    |               |
  51. |    |           output = model(noisy_images)              |               |
  52. |    |           psnr = compute_psnr(clean_images, output) |               |
  53. |    |           total_psnr += psnr                       |               |
  54. |    |                                                   |               |
  55. |    +---------------------------------------------------+               |
  56. |                                                                        |
  57. |    +---------------------------------------------------+               |
  58. |    |                                                   |               |
  59. |    |   avg_psnr = total_psnr / len(clean_loader)        |               |
  60. |    |   end_time = time()                               |               |
  61. |    |                                                   |               |
  62. |    |   print(                                          |               |
  63. |    |       f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f},     |               |
  64. |    |       PSNR: {avg_psnr:.4f}, Time: {end_time-start_time:.2f}")  |               |
  65. |    |                                                   |               |
  66. |    |   if avg_psnr > best_psnr:                         |               |
  67. |    |       best_psnr = avg_psnr                         |               |
  68. |    |       torch.save(model.state_dict(), best_model_save_path)          |               |
  69. |    |       print(f"Saved the best model at epoch {epoch+1}, PSNR: {avg_psnr:.4f}")  |               |
  70. |    |                                                   |               |
  71. |    |   if (epoch+1) % 10 == 0:                          |               |
  72. |    |       torch.save(model.state_dict(), model_save_path.format(epoch))  |               |
  73. |    |       print(f"Saved model at epoch {epoch+1}, PSNR: {avg_psnr:.4f}")  |               |
  74. |    |                                                   |               |
  75. |    |   model.train()                                   |               |
  76. |    |                                                   |               |
  77. |    +---------------------------------------------------+               |
  78. |                                                                        |
  79. +------------------------------------------------------------------------+
复制代码

求最佳答案
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-23 08:53

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表