鱼C论坛

 找回密码
 立即注册
查看: 2536|回复: 7

[已解决]关于ESRGAN复现的问题

[复制链接]
发表于 2020-2-23 07:45:08 | 显示全部楼层 |阅读模式
50鱼币
复现train.py的时候报错 我个人认为可能是ymal文件的路径没有被读取 去options.py设置完好像依然找不到路径 我在网上看的读ymal文件的代码写法和作者有点出入 所以不太懂了 希望大神们不吝赐教
  1. import os
  2. import math
  3. import argparse
  4. import random
  5. import logging

  6. import torch
  7. import torch.distributed as dist
  8. import torch.multiprocessing as mp
  9. from data.data_sampler import DistIterSampler

  10. import options.options as option
  11. from utils import util
  12. from data import create_dataloader, create_dataset
  13. from models import create_model


  14. def init_dist(backend='nccl', **kwargs):
  15.     ''' initialization for distributed training'''
  16.     # if mp.get_start_method(allow_none=True) is None:
  17.     if mp.get_start_method(allow_none=True) != 'spawn':
  18.         mp.set_start_method('spawn')
  19.     rank = int(os.environ['RANK'])
  20.     num_gpus = torch.cuda.device_count()
  21.     torch.cuda.set_device(rank % num_gpus)
  22.     dist.init_process_group(backend=backend, **kwargs)


  23. def main():
  24.     #### options
  25.     parser = argparse.ArgumentParser()
  26.     parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
  27.     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
  28.                         help='job launcher')
  29.     parser.add_argument('--local_rank', type=int, default=0)
  30.     args = parser.parse_args()
  31.     opt = option.parse(args.opt, is_train=True)

  32.     #### distributed training settings
  33.     if args.launcher == 'none':  # disabled distributed training
  34.         opt['dist'] = False
  35.         rank = -1
  36.         print('Disabled distributed training.')
  37.     else:
  38.         opt['dist'] = True
  39.         init_dist()
  40.         world_size = torch.distributed.get_world_size()
  41.         rank = torch.distributed.get_rank()

  42.     #### loading resume state if exists
  43.     if opt['path'].get('resume_state', None):
  44.         # distributed resuming: all load into default GPU
  45.         device_id = torch.cuda.current_device()
  46.         resume_state = torch.load(opt['path']['resume_state'],
  47.                                   map_location=lambda storage, loc: storage.cuda(device_id))
  48.         option.check_resume(opt, resume_state['iter'])  # check resume options
  49.     else:
  50.         resume_state = None

  51.     #### mkdir and loggers
  52.     if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
  53.         if resume_state is None:
  54.             util.mkdir_and_rename(
  55.                 opt['path']['experiments_root'])  # rename experiment folder if exists
  56.             util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
  57.                          and 'pretrain_model' not in key and 'resume' not in key))

  58.         # config loggers. Before it, the log will not work
  59.         util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
  60.                           screen=True, tofile=True)
  61.         util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
  62.                           screen=True, tofile=True)
  63.         logger = logging.getLogger('base')
  64.         logger.info(option.dict2str(opt))
  65.         # tensorboard logger
  66.         if opt['use_tb_logger'] and 'debug' not in opt['name']:
  67.             version = float(torch.__version__[0:3])
  68.             if version >= 1.1:  # PyTorch 1.1
  69.                 from torch.utils.tensorboard import SummaryWriter
  70.             else:
  71.                 logger.info(
  72.                     'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
  73.                 from tensorboardX import SummaryWriter
  74.             tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
  75.     else:
  76.         util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
  77.         logger = logging.getLogger('base')

  78.     # convert to NoneDict, which returns None for missing keys
  79.     opt = option.dict_to_nonedict(opt)

  80.     #### random seed
  81.     seed = opt['train']['manual_seed']
  82.     if seed is None:
  83.         seed = random.randint(1, 10000)
  84.     if rank <= 0:
  85.         logger.info('Random seed: {}'.format(seed))
  86.     util.set_random_seed(seed)

  87.     torch.backends.cudnn.benchmark = True
  88.     # torch.backends.cudnn.deterministic = True

  89.     #### create train and val dataloader
  90.     dataset_ratio = 200  # enlarge the size of each epoch
  91.     for phase, dataset_opt in opt['datasets'].items():
  92.         if phase == 'train':
  93.             train_set = create_dataset(dataset_opt)
  94.             train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
  95.             total_iters = int(opt['train']['niter'])
  96.             total_epochs = int(math.ceil(total_iters / train_size))
  97.             if opt['dist']:
  98.                 train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
  99.                 total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
  100.             else:
  101.                 train_sampler = None
  102.             train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
  103.             if rank <= 0:
  104.                 logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
  105.                     len(train_set), train_size))
  106.                 logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
  107.                     total_epochs, total_iters))
  108.         elif phase == 'val':
  109.             val_set = create_dataset(dataset_opt)
  110.             val_loader = create_dataloader(val_set, dataset_opt, opt, None)
  111.             if rank <= 0:
  112.                 logger.info('Number of val images in [{:s}]: {:d}'.format(
  113.                     dataset_opt['name'], len(val_set)))
  114.         else:
  115.             raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
  116.     assert train_loader is not None

  117.     #### create model
  118.     model = create_model(opt)

  119.     #### resume training
  120.     if resume_state:
  121.         logger.info('Resuming training from epoch: {}, iter: {}.'.format(
  122.             resume_state['epoch'], resume_state['iter']))

  123.         start_epoch = resume_state['epoch']
  124.         current_step = resume_state['iter']
  125.         model.resume_training(resume_state)  # handle optimizers and schedulers
  126.     else:
  127.         current_step = 0
  128.         start_epoch = 0

  129.     #### training
  130.     logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
  131.     for epoch in range(start_epoch, total_epochs + 1):
  132.         if opt['dist']:
  133.             train_sampler.set_epoch(epoch)
  134.         for _, train_data in enumerate(train_loader):
  135.             current_step += 1
  136.             if current_step > total_iters:
  137.                 break
  138.             #### update learning rate
  139.             model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])

  140.             #### training
  141.             model.feed_data(train_data)
  142.             model.optimize_parameters(current_step)

  143.             #### log
  144.             if current_step % opt['logger']['print_freq'] == 0:
  145.                 logs = model.get_current_log()
  146.                 message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
  147.                     epoch, current_step, model.get_current_learning_rate())
  148.                 for k, v in logs.items():
  149.                     message += '{:s}: {:.4e} '.format(k, v)
  150.                     # tensorboard logger
  151.                     if opt['use_tb_logger'] and 'debug' not in opt['name']:
  152.                         if rank <= 0:
  153.                             tb_logger.add_scalar(k, v, current_step)
  154.                 if rank <= 0:
  155.                     logger.info(message)

  156.             # validation
  157.             if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
  158.                 avg_psnr = 0.0
  159.                 idx = 0
  160.                 for val_data in val_loader:
  161.                     idx += 1
  162.                     img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
  163.                     img_dir = os.path.join(opt['path']['val_images'], img_name)
  164.                     util.mkdir(img_dir)

  165.                     model.feed_data(val_data)
  166.                     model.test()

  167.                     visuals = model.get_current_visuals()
  168.                     sr_img = util.tensor2img(visuals['SR'])  # uint8
  169.                     gt_img = util.tensor2img(visuals['GT'])  # uint8

  170.                     # Save SR images for reference
  171.                     save_img_path = os.path.join(img_dir,
  172.                                                  '{:s}_{:d}.png'.format(img_name, current_step))
  173.                     util.save_img(sr_img, save_img_path)

  174.                     # calculate PSNR
  175.                     crop_size = opt['scale']
  176.                     gt_img = gt_img / 255.
  177.                     sr_img = sr_img / 255.
  178.                     cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
  179.                     cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
  180.                     avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)

  181.                 avg_psnr = avg_psnr / idx

  182.                 # log
  183.                 logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
  184.                 logger_val = logging.getLogger('val')  # validation logger
  185.                 logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
  186.                     epoch, current_step, avg_psnr))
  187.                 # tensorboard logger
  188.                 if opt['use_tb_logger'] and 'debug' not in opt['name']:
  189.                     tb_logger.add_scalar('psnr', avg_psnr, current_step)

  190.             #### save models and training states
  191.             if current_step % opt['logger']['save_checkpoint_freq'] == 0:
  192.                 if rank <= 0:
  193.                     logger.info('Saving models and training states.')
  194.                     model.save(current_step)
  195.                     model.save_training_state(epoch, current_step)

  196.     if rank <= 0:
  197.         logger.info('Saving the final model.')
  198.         model.save('latest')
  199.         logger.info('End of training.')


  200. if __name__ == '__main__':
  201.     main()
复制代码

这是train.py的代码
最佳答案
2020-2-23 07:45:09
符明东233 发表于 2020-2-25 00:46
谢大佬 弱弱的问一句 我这是被CUDA劝退了吗

你的电脑上是没有配置任何的深度学习环境么?比如说cuda cudnn pytorch等环境?

这种代码都不是网上查一个库 下载下来就直接能用的,
(第0步,确定你电脑的gpu还可以,必须是英伟达的显卡,1060起步)
首先你需要先配置一个pytorch环境,网上有很多的教程。
然后跑一个“hello world”程序,测试你的pytorch环境是否安装成功。
之后才是根据你下载的代码进行运行测试,如果你是从github下载的代码,一般会在readme.md中介绍这个代码怎么运行,跟着他的说明,一步一步的再把代码跑起来。

报错

报错

最佳答案

查看完整内容

你的电脑上是没有配置任何的深度学习环境么?比如说cuda cudnn pytorch等环境? 这种代码都不是网上查一个库 下载下来就直接能用的, (第0步,确定你电脑的gpu还可以,必须是英伟达的显卡,1060起步) 首先你需要先配置一个pytorch环境,网上有很多的教程。 然后跑一个“hello world”程序,测试你的pytorch环境是否安装成功。 之后才是根据你下载的代码进行运行测试,如果你是从github下载的代码,一般会在readme.md中 ...
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2020-2-23 07:45:09 | 显示全部楼层    本楼为最佳答案   
符明东233 发表于 2020-2-25 00:46
谢大佬 弱弱的问一句 我这是被CUDA劝退了吗

你的电脑上是没有配置任何的深度学习环境么?比如说cuda cudnn pytorch等环境?

这种代码都不是网上查一个库 下载下来就直接能用的,
(第0步,确定你电脑的gpu还可以,必须是英伟达的显卡,1060起步)
首先你需要先配置一个pytorch环境,网上有很多的教程。
然后跑一个“hello world”程序,测试你的pytorch环境是否安装成功。
之后才是根据你下载的代码进行运行测试,如果你是从github下载的代码,一般会在readme.md中介绍这个代码怎么运行,跟着他的说明,一步一步的再把代码跑起来。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2020-2-23 07:46:19 | 显示全部楼层
  1. import os
  2. import os.path as osp
  3. import logging
  4. import yaml
  5. from utils.util import OrderedYaml
  6. Loader, Dumper = OrderedYaml()


  7. def parse(opt_path, is_train=True):
  8.     with open(opt_path, mode='r') as f:
  9.          opt = yaml.load(f, Loader=Loader)
  10.     # export CUDA_VISIBLE_DEVICES
  11.     gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
  12.     os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
  13.     print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

  14.     opt['is_train'] = is_train
  15.     if opt['distortion'] == 'sr':
  16.         scale = opt['scale']

  17.     # datasets
  18.     for phase, dataset in opt['datasets'].items():
  19.         phase = phase.split('_')[0]
  20.         dataset['phase'] = phase
  21.         if opt['distortion'] == 'sr':
  22.             dataset['scale'] = scale
  23.         is_lmdb = False
  24.         if dataset.get('dataroot_GT', None) is not None:
  25.             dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
  26.             if dataset['dataroot_GT'].endswith('lmdb'):
  27.                 is_lmdb = True
  28.         # if dataset.get('dataroot_GT_bg', None) is not None:
  29.         #     dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg'])
  30.         if dataset.get('dataroot_LQ', None) is not None:
  31.             dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
  32.             if dataset['dataroot_LQ'].endswith('lmdb'):
  33.                 is_lmdb = True
  34.         dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
  35.         if dataset['mode'].endswith('mc'):  # for memcached
  36.             dataset['data_type'] = 'mc'
  37.             dataset['mode'] = dataset['mode'].replace('_mc', '')

  38.     # path
  39.     for key, path in opt['path'].items():
  40.         if path and key in opt['path'] and key != 'strict_load':
  41.             opt['path'][key] = osp.expanduser(path)
  42.     opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
  43.     if is_train:
  44.         experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
  45.         opt['path']['experiments_root'] = experiments_root
  46.         opt['path']['models'] = osp.join(experiments_root, 'models')
  47.         opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
  48.         opt['path']['log'] = experiments_root
  49.         opt['path']['val_images'] = osp.join(experiments_root, 'val_images')

  50.         # change some options for debug mode
  51.         if 'debug' in opt['name']:
  52.             opt['train']['val_freq'] = 8
  53.             opt['logger']['print_freq'] = 1
  54.             opt['logger']['save_checkpoint_freq'] = 8
  55.     else:  # test
  56.         results_root = osp.join(opt['path']['root'], 'results', opt['name'])
  57.         opt['path']['results_root'] = results_root
  58.         opt['path']['log'] = results_root

  59.     # network
  60.     if opt['distortion'] == 'sr':
  61.         opt['network_G']['scale'] = scale

  62.     return opt


  63. def dict2str(opt, indent_l=1):
  64.     '''dict to string for logger'''
  65.     msg = ''
  66.     for k, v in opt.items():
  67.         if isinstance(v, dict):
  68.             msg += ' ' * (indent_l * 2) + k + ':[\n'
  69.             msg += dict2str(v, indent_l + 1)
  70.             msg += ' ' * (indent_l * 2) + ']\n'
  71.         else:
  72.             msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
  73.     return msg


  74. class NoneDict(dict):
  75.     def __missing__(self, key):
  76.         return None


  77. # convert to NoneDict, which return None for missing key.
  78. def dict_to_nonedict(opt):
  79.     if isinstance(opt, dict):
  80.         new_opt = dict()
  81.         for key, sub_opt in opt.items():
  82.             new_opt[key] = dict_to_nonedict(sub_opt)
  83.         return NoneDict(**new_opt)
  84.     elif isinstance(opt, list):
  85.         return [dict_to_nonedict(sub_opt) for sub_opt in opt]
  86.     else:
  87.         return opt


  88. def check_resume(opt, resume_iter):
  89.     '''Check resume states and pretrain_model paths'''
  90.     logger = logging.getLogger('base')
  91.     if opt['path']['resume_state']:
  92.         if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
  93.                 'pretrain_model_D', None) is not None:
  94.             logger.warning('pretrain_model path will be ignored when resuming training.')

  95.         opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
  96.                                                    '{}_G.pth'.format(resume_iter))
  97.         logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
  98.         if 'gan' in opt['model']:
  99.             opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
  100.                                                        '{}_D.pth'.format(resume_iter))
  101.             logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
复制代码

这是options.py的
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2020-2-23 11:57:25 | 显示全部楼层
  1. opt = option.parse(args.opt, is_train=True)
复制代码

报错信息找到问题出在这里,说你第一个参数应该是一个string或其他类型的文件地址(比如:f:/xxx/xxx/xxx.txt),但是你传的参数是一个nonetype

那你就看看args.opt这个参数是怎么赋值的,
往前找能看到:
  1. parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
复制代码

那么问题应该就出在这里了,你在执行这个py文件的时候,应该是没有指定-opt这个参数。

比如你第一个文件叫main.py吧,那么执行的时候应该这么写:
  1. python3 main.py -opt "F:/xxx/你的文件地址.txt"
复制代码

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2020-2-23 15:33:53 | 显示全部楼层
shuofxz 发表于 2020-2-23 11:57
报错信息找到问题出在这里,说你第一个参数应该是一个string或其他类型的文件地址(比如:f:/xxx/xxx/xxx ...

谢谢您的答复 我用anaconda的命令行将您说的ymal的地址传到opt中了(我也不知道传的是不是对的)然后显示还是有错误 还劳烦请您看看
2.png
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2020-2-23 18:42:58 | 显示全部楼层
符明东233 发表于 2020-2-23 15:33
谢谢您的答复 我用anaconda的命令行将您说的ymal的地址传到opt中了(我也不知道传的是不是对的)然后显示 ...

-opt 后面加一个空格
  1. -opt "C:/xxxx"
复制代码
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2020-2-25 00:46:06 | 显示全部楼层
shuofxz 发表于 2020-2-23 18:42
-opt 后面加一个空格

谢大佬 弱弱的问一句 我这是被CUDA劝退了吗
123.png
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2020-2-25 06:29:01 | 显示全部楼层
你有一个未命名列表
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-4-25 15:09

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表