|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
源码:
import torch
import torch.nn as nn
from metric import get_stoi, get_pesq
from scipy.io import wavfile
import numpy as np
from checkpoints import Checkpoint
from torch.utils.data import DataLoader
from helper_funcs import snr, numParams
from eval_composite import eval_composite
from AudioData import EvalDataset, EvalCollate
from new_model import Net
import h5py
import os
# Restrict CUDA to the first GPU; set before any CUDA call is made below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Sample rate passed to the metric code and used when writing enhanced audio.
sr = 16000

# Alternative single-condition setup kept from the original script:
#file_name = 'psquare_17.5'
#test_file_list_path = '/media/concordia/DATA/KaiWang/pytorch_learn/pytorch_for_speech/voice_bank/Transformer/v5/test_file_break' + '/' + file_name
#audio_file_save = 'D:/pycharmProject/TSTNN-master/Mydataset/enhanced_audio' + '/' + 'enhanced_' + file_name
test_file_list_path = "D:/pycharmProject/TSTNN-master/test_file_list"
audio_file_save = "D:/pycharmProject/TSTNN-master/Mydataset/enhanced_audio/"

# exist_ok=True replaces the isdir()-then-makedirs() pair, which could race
# with another process creating the directory in between.
os.makedirs(audio_file_save, exist_ok=True)

with open(test_file_list_path, 'r') as test_file_list:
    # One path per line. Entries are deliberately NOT filtered: evaluate()
    # indexes this list by batch number, so it must stay aligned with the
    # order in which the dataset reads the same file.
    file_list = [line.strip() for line in test_file_list]
#audio_name = os.path.basename(file_list[0])
print(file_list)
# Evaluation data: one utterance per batch, in file-list order (no shuffling),
# framed with a 512-sample window and a 256-sample shift.
test_data = EvalDataset(test_file_list_path, frame_size=512, frame_shift=256)
test_loader = DataLoader(
    test_data,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    collate_fn=EvalCollate(),
)

ckpt_path = 'D:/pycharmProject/TSTNN-master/checkpoints/best.model'

# The network is wrapped in DataParallel before the weights are loaded.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model
# (keys prefixed with "module.") -- confirm against the training script.
model = nn.DataParallel(Net(), device_ids=[0])

checkpoint = Checkpoint()
checkpoint.load(ckpt_path)
model.load_state_dict(checkpoint.state_dict)
model.cuda()

# Report which checkpoint was restored and how large the model is.
print(checkpoint.start_epoch)
print(checkpoint.best_val_loss)
print(numParams(model))
# test function
def evaluate(net, eval_loader):
    """Enhance every item of ``eval_loader`` and average the quality metrics.

    For each ``(features, labels)`` batch the network output is squeezed,
    truncated to the length of ``labels``, scored with ``eval_composite`` and
    written as a float32 wav file into ``audio_file_save``. The output file
    name is the basename of ``file_list[k]`` for the k-th batch, so the
    module-level ``file_list`` must be in the same order as the loader.

    Args:
        net: Model invoked as ``net(features)``; put into eval mode here.
        eval_loader: Iterable yielding ``(features, labels)`` tensor pairs.

    Returns:
        Tuple ``(avg_eval_loss, stoi, pesq, ssnr, csig, cbak, covl)`` where
        every entry is the mean over all evaluated batches.

    Raises:
        ValueError: If ``eval_loader`` yields no batches (previously this
            surfaced as a ZeroDivisionError).
    """
    net.eval()
    print('********Starting metrics evaluation on test dataset**********')

    # Order here defines the order of the returned averages; callers unpack
    # positionally, so do not reorder.
    metric_names = ('stoi', 'pesq', 'ssnr', 'csig', 'cbak', 'covl')
    totals = dict.fromkeys(metric_names, 0.0)
    total_eval_loss = 0.0
    count = 0

    with torch.no_grad():
        for k, (features, labels) in enumerate(eval_loader):
            # Shape notes are the original author's annotations.
            features = features.cuda()  # [1, 1, num_frames, frame_size]
            labels = labels.cuda()  # [signal_len, ]

            output = net(features)  # [1, 1, sig_len_recover]
            output = output.squeeze()  # [sig_len_recover, ]
            # Trim the recovered signal so it has the same length as the label.
            output = output[:labels.shape[-1]]

            eval_loss = torch.mean((output - labels) ** 2)
            total_eval_loss += eval_loss.item()

            est_sp = output.cpu().numpy()
            cln_raw = labels.cpu().numpy()

            eval_metric = eval_composite(cln_raw, est_sp, sr)
            #st = get_stoi(cln_raw, est_sp, sr)
            #pe = get_pesq(cln_raw, est_sp, sr)
            #sn = snr(cln_raw, est_sp)
            for name in metric_names:
                totals[name] += eval_metric[name]

            wavfile.write(
                os.path.join(audio_file_save, os.path.basename(file_list[k])),
                sr,
                est_sp.astype(np.float32),
            )
            count += 1

    if count == 0:
        raise ValueError('eval_loader yielded no batches; cannot average metrics')

    avg_eval_loss = total_eval_loss / count
    return (avg_eval_loss,) + tuple(totals[name] / count for name in metric_names)
def eva_noisy(file_path):
    """Score the unprocessed noisy recordings against their clean references.

    Reads a list of HDF5 file paths (one per line) from ``file_path``. Each
    HDF5 file must provide the datasets ``'noisy_raw'`` and ``'clean_raw'``,
    which are compared with ``eval_composite`` at the module-level sample
    rate ``sr``.

    Args:
        file_path: Path of the text file listing the HDF5 files. Blank lines
            are ignored.

    Returns:
        Tuple ``(stoi, pesq, ssnr, cbak, csig, covl)`` of averages over all
        listed files. NOTE: ``cbak`` comes before ``csig`` here, which differs
        from the order returned by ``evaluate``.

    Raises:
        ValueError: If the list contains no file paths (previously this
            surfaced as a ZeroDivisionError).
    """
    print('********Starting metrics evaluation on raw noisy data**********')

    with open(file_path, 'r') as eva_file_list:
        stripped = (line.strip() for line in eva_file_list)
        noisy_files = [name for name in stripped if name]

    if not noisy_files:
        raise ValueError('no files listed in {!r}; cannot average metrics'.format(file_path))

    # Order here defines the order of the returned averages; callers unpack
    # positionally, so do not reorder.
    metric_names = ('stoi', 'pesq', 'ssnr', 'cbak', 'csig', 'covl')
    totals = dict.fromkeys(metric_names, 0.0)

    for filename in noisy_files:
        # The context manager closes the HDF5 handle; [:] copies the data
        # out of the file first, so the arrays stay valid afterwards.
        with h5py.File(filename, 'r') as reader:
            noisy_raw = reader['noisy_raw'][:]
            cln_raw = reader['clean_raw'][:]

        eval_metric = eval_composite(cln_raw, noisy_raw, sr)
        for name in metric_names:
            totals[name] += eval_metric[name]

    count = len(noisy_files)
    return tuple(totals[name] / count for name in metric_names)
# The DataLoader above uses worker processes (num_workers=4). On platforms
# that start workers with "spawn" (e.g. Windows) each worker re-imports this
# module, so the evaluation must only be launched from the real entry point.
if __name__ == '__main__':
    avg_eval, avg_stoi, avg_pesq, avg_ssnr, avg_csig, avg_cbak, avg_covl = evaluate(model, test_loader)
    # To score the unprocessed noisy inputs instead, use eva_noisy (note that
    # it returns cbak before csig):
    #avg_stoi, avg_pesq, avg_ssnr, avg_cbak, avg_csig, avg_covl = eva_noisy(test_file_list_path)
    #print('Avg_loss: {:.4f}'.format(avg_eval))
    print('STOI: {:.4f}'.format(avg_stoi))
    print('SSNR: {:.4f}'.format(avg_ssnr))
    print('PESQ: {:.4f}'.format(avg_pesq))
    print('CSIG: {:.4f}'.format(avg_csig))
    print('CBAK: {:.4f}'.format(avg_cbak))
    print('COVL: {:.4f}'.format(avg_covl))
|
|