generator_ll.py source code

python

Project: WaveNet-Enhancement    Author: auspicious3000
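import numpy as np  # needed by this method (module-level import in the full file)

def run_semi_online(self, sess, inputs_clean, inputs_noisy, num_samples):
    """Generate num_samples clean-signal bin indices autoregressively.

    "Semi-online": the noisy-input branch is evaluated once over the whole
    noisy signal, while the clean branch is fed one sample per step.
    """
    # Initialize the generation state (init_ops) from the first len_pad + 1 clean samples.
    sess.run(self.init_ops,
             feed_dict={self.history_clean: inputs_clean[:, 0:self.len_pad + 1]})
    # Skip-connection sums of the noisy branch for every time step, computed in one pass.
    skips_noisy_sum = sess.run(self.skips_noisy_sum,
                               feed_dict={self.inputs_noisy: inputs_noisy})
    # Seed the autoregression with the last history sample, shape (batch, 1).
    indices = inputs_clean[:, self.len_pad:self.len_pad + 1]
    predictions_ = []
    for step in range(num_samples):
        # Teacher-forcing alternative: feed the ground-truth clean sample instead.
        # indices = inputs_clean[:, self.len_pad + step:self.len_pad + 1 + step]
        feed_dict = {self.inputs_clean: indices,
                     self.skips_noisy: skips_noisy_sum[:, :, step]}
        # Distribution over quantization bins for the next sample, (batch, num_bins).
        output_dist = sess.run(self.out_ops, feed_dict=feed_dict)[0]
        # Greedy alternative: take the most likely bin.
        # indices = np.argmax(output_dist, axis=1)[:, None]
        # inputs = self.bins_center[indices[:, 0]].astype(np.float32)
        # Expected value of the distribution, re-quantized to a bin index.
        inputs = np.matmul(output_dist, self.bins_center).astype(np.float32)
        indices = np.digitize(inputs, self.bins_edge, right=False)[:, None]
        predictions_.append(indices)

    predictions = np.concatenate(predictions_, axis=1)  # (batch, num_samples)
    # Run the dequeue ops after generation (presumably to drain or reset the state queues).
    sess.run(self.dequ_ops)

    return predictions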
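Inside the loop, the network's softmax over quantization bins is turned back into a bin index in two stages: the distribution's mean (`output_dist @ bins_center`) is computed and then re-quantized with `np.digitize`. The commented-out lines show the greedy alternative (argmax). The NumPy sketch below contrasts the two decoders on a toy distribution and wraps the step in a loop that mirrors the semi-online structure. It is a sketch under stated assumptions, not the project's code: the uniform 8-bin quantizer and the names `make_uniform_bins`, `decode_expected`, `decode_argmax`, `toy_semi_online` and `step_fn` are hypothetical, since the real `bins_center`, `bins_edge` and network are not shown in this excerpt.

```python
import numpy as np


def make_uniform_bins(num_bins, low=-1.0, high=1.0):
    """Hypothetical quantizer: uniform bins on [low, high].

    Returns (centers, interior_edges). Using interior edges only makes
    np.digitize return indices 0..num_bins-1, which index into centers.
    """
    edges = np.linspace(low, high, num_bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2.0
    return centers, edges[1:-1]


def decode_expected(output_dist, bins_center, bins_edge):
    """Mean decoding, as in run_semi_online: expected value, then re-quantize."""
    values = np.matmul(output_dist, bins_center).astype(np.float32)  # (batch,)
    indices = np.digitize(values, bins_edge, right=False)[:, None]   # (batch, 1)
    return values, indices


def decode_argmax(output_dist, bins_center):
    """Greedy decoding, the alternative left commented out in the original."""
    indices = np.argmax(output_dist, axis=1)[:, None]                # (batch, 1)
    values = bins_center[indices[:, 0]].astype(np.float32)           # (batch,)
    return values, indices


def toy_semi_online(step_fn, cond, seed_indices, num_samples, bins_center, bins_edge):
    """Hypothetical stand-in for the generation loop.

    step_fn(indices, cond_t) returns a (batch, num_bins) distribution and
    replaces sess.run(self.out_ops, ...). cond, shaped (batch, channels,
    num_samples) like skips_noisy_sum, is computed once before the loop.
    """
    indices = seed_indices
    preds = []
    for step in range(num_samples):
        dist = step_fn(indices, cond[:, :, step])
        _, indices = decode_expected(dist, bins_center, bins_edge)
        preds.append(indices)
    return np.concatenate(preds, axis=1)  # (batch, num_samples)


if __name__ == "__main__":
    centers, edges = make_uniform_bins(8)
    # Row 0: all mass on bin 5. Row 1: bimodal, 0.6 on bin 0 and 0.4 on bin 7.
    dist = np.zeros((2, 8), dtype=np.float32)
    dist[0, 5] = 1.0
    dist[1, 0], dist[1, 7] = 0.6, 0.4
    print(decode_expected(dist, centers, edges)[1][:, 0].tolist())  # → [5, 3]
    print(decode_argmax(dist, centers)[1][:, 0].tolist())           # → [5, 0]
```

On a peaked distribution the two decoders agree. On the bimodal row, mean decoding lands in bin 3, which the model assigned zero probability, while argmax commits to the stronger mode. Which behavior is preferable for enhancement is an empirical question this excerpt does not settle.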