test_cnn.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:CNN_denoise 作者: weedwind 项目源码 文件源码
def gen_post(feat_list, stat_file, model, win_size_before=15, win_size_after=15,
             num_targets=31, chunk_size=100):
    """Run ``model`` over every feature file listed in ``feat_list`` and save the results.

    For each non-blank path listed in ``feat_list`` this:

    1. reads the utterance via ``htk_io`` and normalizes it with the mean and
       variance vectors from ``stat_file``: ``(x - m) / (sqrt(v) + eps)``;
    2. builds the network input with ``org_data`` using the given context
       window sizes;
    3. feeds that input to ``model`` at most ``chunk_size`` rows at a time,
       collecting ``num_targets`` outputs per row;
    4. applies an orthonormal type-II DCT along the output axis, keeps
       coefficients ``1..numcep`` (coefficient 0 is dropped), and appends
       ``delta(..., 2)`` and delta-of-delta columns;
    5. writes the stacked features with ``htk_io`` to the input path with its
       extension replaced by ``.mfc``.

    Args:
        feat_list: path of a text file listing one input feature file per line.
        stat_file: path passed to ``read_mv`` to obtain the mean/variance vectors.
        model: callable network; it is put in eval mode and called on each chunk.
        win_size_before: context size (before) forwarded to ``org_data``.
        win_size_after: context size (after) forwarded to ``org_data``.
        num_targets: number of output columns ``model`` produces per row.
        chunk_size: maximum number of rows per forward pass (bounds device memory).

    Raises:
        Exception: if ``read_mv`` returns ``None`` for the mean or the variance.
        ValueError: if ``chunk_size`` is not positive, or an input path already
            ends in ``.mfc`` (its output would overwrite the input file).
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    # Evaluation mode: changes the behaviour of layers such as dropout or
    # batch-norm, if the model contains any (the opposite of model.train()).
    model.eval()

    m, v = read_mv(stat_file)
    if m is None or v is None:
        raise Exception("mean or variance vector does not exist")
    # Same divisor for every utterance, so compute it once.
    std_scale = np.sqrt(v) + eps

    with open(feat_list) as f:
        for line in f:
            in_path = line.strip()
            if not in_path:
                continue
            # Resolve the output path first so a bad name fails before the
            # expensive forward pass.
            out_file = _post_output_path(in_path)

            print("generating features for file", in_path)
            # NOTE(review): htk_io handles are never closed explicitly (same as
            # before); confirm whether htk_io requires a close() call.
            utt_feat = htk_io.fopen(in_path).getall()
            utt_feat -= m          # normalize mean
            utt_feat /= std_scale  # normalize var
            feat_numpy = org_data(utt_feat, win_size_before, win_size_after)
            out_feat = _forward_in_chunks(model, feat_numpy, num_targets, chunk_size)

            out_feat = dct(out_feat, type=2, axis=1, norm='ortho')[:, 1:numcep + 1]
            out_feat_delta = delta(out_feat, 2)
            out_feat_ddelta = delta(out_feat_delta, 2)
            out_feat = np.concatenate((out_feat, out_feat_delta, out_feat_ddelta), axis=1)

            writer = htk_io.fopen(out_file, mode="wb", veclen=out_feat.shape[1])
            writer.writeall(out_feat)
            print("features saved in %s\n" % out_file)


def _forward_in_chunks(model, feats, num_targets, chunk_size):
    """Return ``model`` outputs for every row of ``feats`` as a numpy array.

    ``feats`` is sliced into consecutive chunks of at most ``chunk_size`` rows,
    and the last chunk may be shorter. Each chunk is converted to ``gpu_dtype``
    (defined elsewhere; presumably a CUDA tensor type) and run through
    ``model``. Its outputs are stored at the same rows of the result, so the
    result has ``feats.shape[0]`` rows and ``num_targets`` columns.
    """
    num_rows = feats.shape[0]
    out = np.zeros((num_rows, num_targets))
    # Inference only: torch.no_grad() stops autograd from keeping activations
    # alive. The removed `volatile=True` flag is a no-op in PyTorch >= 0.4.
    with torch.no_grad():
        for start in range(0, num_rows, chunk_size):
            stop = start + chunk_size
            batch = torch.from_numpy(feats[start:stop]).type(gpu_dtype)
            out[start:stop] = model(batch).detach().cpu().numpy()
    return out


def _post_output_path(in_path):
    """Return ``in_path`` with its file extension replaced by ``.mfc``.

    Only the final extension is swapped; directory names are never rewritten.
    Raises ValueError when the result would equal ``in_path``, because writing
    there would overwrite the input file.
    """
    import os

    root, _ = os.path.splitext(in_path)
    out_path = root + ".mfc"
    if out_path == in_path:
        raise ValueError("refusing to overwrite input file %s" % in_path)
    return out_path
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号