鱼C论坛

 找回密码
 立即注册
查看: 3714|回复: 2

关于计算PSNR的时候报错

[复制链接]
发表于 2020-2-28 04:27:27 | 显示全部楼层 |阅读模式
50鱼币
请大佬们看一下:为什么我的图片处理后的格式不在区间内?应该怎样解决呢?谢谢!
  1. import os
  2. import sys
  3. import time
  4. import math
  5. from datetime import datetime
  6. import random
  7. import logging
  8. from collections import OrderedDict
  9. import numpy as np
  10. import cv2
  11. import torch
  12. from torchvision.utils import make_grid
  13. from shutil import get_terminal_size

  14. import yaml
  15. try:
  16.     from yaml import CLoader as Loader, CDumper as Dumper
  17. except ImportError:
  18.     from yaml import Loader, Dumper


  19. def OrderedYaml():
  20.     '''yaml orderedDict support'''
  21.     _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

  22.     def dict_representer(dumper, data):
  23.         return dumper.represent_dict(data.items())

  24.     def dict_constructor(loader, node):
  25.         return OrderedDict(loader.construct_pairs(node))

  26.     Dumper.add_representer(OrderedDict, dict_representer)
  27.     Loader.add_constructor(_mapping_tag, dict_constructor)
  28.     return Loader, Dumper


  29. ####################
  30. # miscellaneous
  31. ####################


  32. def get_timestamp():
  33.     return datetime.now().strftime('%y%m%d-%H%M%S')


  34. def mkdir(path):
  35.     if not os.path.exists(path):
  36.         os.makedirs(path)


  37. def mkdirs(paths):
  38.     if isinstance(paths, str):
  39.         mkdir(paths)
  40.     else:
  41.         for path in paths:
  42.             mkdir(path)


  43. def mkdir_and_rename(path):
  44.     if os.path.exists(path):
  45.         new_name = path + '_archived_' + get_timestamp()
  46.         print('Path already exists. Rename it to [{:s}]'.format(new_name))
  47.         logger = logging.getLogger('base')
  48.         logger.info('Path already exists. Rename it to [{:s}]'.format(new_name))
  49.         os.rename(path, new_name)
  50.     os.makedirs(path)


  51. def set_random_seed(seed):
  52.     random.seed(seed)
  53.     np.random.seed(seed)
  54.     torch.manual_seed(seed)
  55.     torch.cuda.manual_seed_all(seed)


  56. def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
  57.     '''set up logger'''
  58.     lg = logging.getLogger(logger_name)
  59.     formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
  60.                                   datefmt='%y-%m-%d %H:%M:%S')
  61.     lg.setLevel(level)
  62.     if tofile:
  63.         log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
  64.         fh = logging.FileHandler(log_file, mode='w')
  65.         fh.setFormatter(formatter)
  66.         lg.addHandler(fh)
  67.     if screen:
  68.         sh = logging.StreamHandler()
  69.         sh.setFormatter(formatter)
  70.         lg.addHandler(sh)


  71. ####################
  72. # image convert
  73. ####################


  74. def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
  75.     '''
  76.     Converts a torch Tensor into an image Numpy array
  77.     Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
  78.     Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
  79.     '''
  80.     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
  81.     tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
  82.     n_dim = tensor.dim()
  83.     if n_dim == 4:
  84.         n_img = len(tensor)
  85.         img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
  86.         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
  87.     elif n_dim == 3:
  88.         img_np = tensor.numpy()
  89.         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
  90.     elif n_dim == 2:
  91.         img_np = tensor.numpy()
  92.     else:
  93.         raise TypeError(
  94.             'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
  95.     if out_type == np.uint8:
  96.         img_np = (img_np * 255.0).round()
  97.         # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
  98.     return img_np.astype(out_type)


  99. def save_img(img, img_path, mode='RGB'):
  100.     cv2.imwrite(img_path, img)


  101. ####################
  102. # metric
  103. ####################


  104. def calculate_psnr(img1, img2):
  105.     # img1 and img2 have range [0, 255]
  106.     img1 = img1.astype(np.float64)
  107.     img2 = img2.astype(np.float64)
  108.     mse = np.mean((img1 - img2)**2)
  109.     if mse == 0:
  110.         return float('inf')
  111.     return 20 * math.log10(255.0 / math.sqrt(mse))


  112. def ssim(img1, img2):
  113.     C1 = (0.01 * 255)**2
  114.     C2 = (0.03 * 255)**2

  115.     img1 = img1.astype(np.float64)
  116.     img2 = img2.astype(np.float64)
  117.     kernel = cv2.getGaussianKernel(11, 1.5)
  118.     window = np.outer(kernel, kernel.transpose())

  119.     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
  120.     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
  121.     mu1_sq = mu1**2
  122.     mu2_sq = mu2**2
  123.     mu1_mu2 = mu1 * mu2
  124.     sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
  125.     sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
  126.     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

  127.     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
  128.                                                             (sigma1_sq + sigma2_sq + C2))
  129.     return ssim_map.mean()


  130. def calculate_ssim(img1, img2):
  131.     '''calculate SSIM
  132.     the same outputs as MATLAB's
  133.     img1, img2: [0, 255]
  134.     '''
  135.     if not img1.shape == img2.shape:
  136.         raise ValueError('Input images must have the same dimensions.')
  137.     if img1.ndim == 2:
  138.         return ssim(img1, img2)
  139.     elif img1.ndim == 3:
  140.         if img1.shape[2] == 3:
  141.             ssims = []
  142.             for i in range(3):
  143.                 ssims.append(ssim(img1, img2))
  144.             return np.array(ssims).mean()
  145.         elif img1.shape[2] == 1:
  146.             return ssim(np.squeeze(img1), np.squeeze(img2))
  147.     else:
  148.         raise ValueError('Wrong input image dimensions.')


  149. class ProgressBar(object):
  150.     '''A progress bar which can print the progress
  151.     modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
  152.     '''

  153.     def __init__(self, task_num=0, bar_width=50, start=True):
  154.         self.task_num = task_num
  155.         max_bar_width = self._get_max_bar_width()
  156.         self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
  157.         self.completed = 0
  158.         if start:
  159.             self.start()

  160.     def _get_max_bar_width(self):
  161.         terminal_width, _ = get_terminal_size()
  162.         max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
  163.         if max_bar_width < 10:
  164.             print('terminal width is too small ({}), please consider widen the terminal for better '
  165.                   'progressbar visualization'.format(terminal_width))
  166.             max_bar_width = 10
  167.         return max_bar_width

  168.     def start(self):
  169.         if self.task_num > 0:
  170.             sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
  171.                 ' ' * self.bar_width, self.task_num, 'Start...'))
  172.         else:
  173.             sys.stdout.write('completed: 0, elapsed: 0s')
  174.         sys.stdout.flush()
  175.         self.start_time = time.time()

  176.     def update(self, msg='In progress...'):
  177.         self.completed += 1
  178.         elapsed = time.time() - self.start_time
  179.         fps = self.completed / elapsed
  180.         if self.task_num > 0:
  181.             percentage = self.completed / float(self.task_num)
  182.             eta = int(elapsed * (1 - percentage) / percentage + 0.5)
  183.             mark_width = int(self.bar_width * percentage)
  184.             bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
  185.             sys.stdout.write('\033[2F')  # cursor up 2 lines
  186.             sys.stdout.write('\033[J')  # clean the output (remove extra chars since last display)
  187.             sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
  188.                 bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
  189.         else:
  190.             sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
  191.                 self.completed, int(elapsed + 0.5), fps))
  192.         sys.stdout.flush()
复制代码

QQ截图20200228042619.png
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2020-2-29 20:22:35 | 显示全部楼层
代码太多,看不懂,能告诉我这是干啥的么?
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2020-2-29 20:34:30 | 显示全部楼层
看不懂
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2026-1-22 23:37

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表