def test(netG, opt, noise_dim=None, ac_dim=47):
    """Run the generator over every ``.cmp`` file in ``opt.testdata_dir``.

    For each ``<name>.cmp`` file:

    1. The conditioning data is read with ``read_binary_file(..., dim=ac_dim)``.
    2. One standard-normal noise vector is drawn per row of that data.
    3. ``netG(noise, ac_data)`` is evaluated.
    4. The output is reshaped to ``(rows, -1)`` and written as raw binary
       (``ndarray.tofile``) to ``<name>.pls`` in the same directory.

    NOTE(review): ``read_binary_file`` is presumed to return a 2-D
    ``(frames, ac_dim)`` array; the code only relies on ``size(0)`` being
    the row count. Confirm against its definition.

    Args:
        netG: Generator, called as ``netG(noise, ac_data)``.
        opt: Options object. Reads ``opt.netG`` (must not be ``''``),
            ``opt.testdata_dir`` and ``opt.cuda``.
        noise_dim: Width of each noise vector. Defaults to the module-level
            ``nz``, so existing callers behave as before.
        ac_dim: Value passed to ``read_binary_file`` as ``dim``
            (previously hard-coded to 47).

    Raises:
        ValueError: If ``opt.netG`` is the empty string.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if opt.netG == '':
        raise ValueError(
            "opt.netG is empty; a trained generator checkpoint is required for testing"
        )
    if noise_dim is None:
        noise_dim = nz  # module-level global, as used by the original code

    test_dir = opt.testdata_dir
    # os.listdir order is arbitrary; sort for a reproducible processing order.
    for f in sorted(os.listdir(test_dir)):
        fname, ext = os.path.splitext(f)
        if ext != '.cmp':
            continue
        print(fname)

        ac_data = read_binary_file(os.path.join(test_dir, f), dim=ac_dim)
        ac_data = torch.FloatTensor(ac_data)
        if opt.cuda:
            ac_data = ac_data.cuda()
        num_rows = ac_data.size(0)

        # Inference only: no_grad() avoids building an autograd graph that the
        # previous Variable-based code created for every file.
        # NOTE(review): netG is not switched to eval() here (unchanged from the
        # original); if it contains BatchNorm/Dropout, confirm this is intended.
        with torch.no_grad():
            # Same distribution as FloatTensor(...).normal_(0, 1), drawn
            # directly on the device that holds the conditioning data.
            noise = torch.randn(num_rows, noise_dim, device=ac_data.device)
            generated_pulses = netG(noise, ac_data)

        generated_pulses = generated_pulses.detach().cpu().numpy()
        generated_pulses = generated_pulses.reshape(num_rows, -1)

        out_file = os.path.join(test_dir, fname + '.pls')
        # tofile writes raw values in native byte order with no header.
        with open(out_file, 'wb') as fid:
            generated_pulses.tofile(fid)
评论列表
文章目录