def run_semi_online(self, sess, inputs_clean, inputs_noisy, num_samples):
    """Generate bin indices step by step, feeding each prediction back in.

    The first ``self.len_pad + 1`` columns of ``inputs_clean`` are fed to
    ``self.init_ops`` as history, and ``self.skips_noisy_sum`` is evaluated
    once on ``inputs_noisy``.  Each step then runs ``self.out_ops`` on the
    latest index column plus the matching slice of the skip sum, converts
    the output into a new index column, and uses that column as the clean
    input of the following step.  ``self.dequ_ops`` is run once at the end.

    Args:
        sess: Object exposing ``run(fetches, feed_dict=...)``
            (presumably a TensorFlow session -- confirm against callers).
        inputs_clean: Array indexable as ``[:, a:b]``; supplies the history
            and the first index column.
        inputs_noisy: Value fed to ``self.inputs_noisy``.
        num_samples: Number of generation steps.  The evaluated skip sum is
            indexed as ``[:, :, step]``, so its third axis must provide at
            least this many entries.

    Returns:
        Array with one column per step, each holding the ``np.digitize``
        bin indices produced at that step.  When ``num_samples`` is not
        positive the result has zero columns.
    """
    # Load the clean history into the graph before stepping.
    sess.run(
        self.init_ops,
        feed_dict={self.history_clean: inputs_clean[:, 0:self.len_pad + 1]},
    )
    # The noisy-branch skip sum does not depend on the generated samples,
    # so it is evaluated once for the whole sequence.
    skips_noisy_sum = sess.run(
        self.skips_noisy_sum,
        feed_dict={self.inputs_noisy: inputs_noisy},
    )

    indices = inputs_clean[:, self.len_pad:self.len_pad + 1]
    num_rows = indices.shape[0]

    step_predictions = []
    # `range` behaves the same as the former `xrange` here and also exists
    # on Python 3.
    for step in range(num_samples):
        feed_dict = {
            self.inputs_clean: indices,
            self.skips_noisy: skips_noisy_sum[:, :, step],
        }
        output_dist = sess.run(self.out_ops, feed_dict=feed_dict)[0]
        # Weight the bin centres by the network output (presumably a
        # probability distribution over bins, making this its expected
        # value), then map the result back to a bin index.
        inputs = np.matmul(output_dist, self.bins_center).astype(np.float32)
        indices = np.digitize(inputs, self.bins_edge, right=False)[:, None]
        step_predictions.append(indices)

    if step_predictions:
        predictions = np.concatenate(step_predictions, axis=1)
    else:
        # np.concatenate rejects an empty list; return an empty result so
        # the closing dequ_ops run below still happens.
        predictions = np.empty((num_rows, 0), dtype=np.intp)

    sess.run(self.dequ_ops)
    return predictions
generator_ll.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录