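import tensorflow as tf

def transition(self, curr_state, next_symbols, batch_size):
    with tf.name_scope('grammar_transition'):
        # Look up the transition row for each example's current state:
        # shape [batch_size, output_size].
        transitions = tf.gather(tf.constant(self.transition_matrix), curr_state)
        assert transitions.get_shape()[1:] == (self.output_size,)
        # Pair each batch index i with that example's next symbol ...
        indices = tf.stack((tf.range(0, batch_size), next_symbols), axis=1)
        # ... and select transitions[i, next_symbols[i]] for every i.
        next_state = tf.gather_nd(transitions, indices)
        return next_state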
Source file: thingtalk.py
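To make the indexing concrete, here is a minimal pure-Python sketch of what the two TensorFlow ops compute. It assumes `transition_matrix` is a 2-D integer table indexed as [state, symbol], which is what the `output_size` shape check suggests. The function `transition_reference` and the toy matrix are hypothetical and exist only for illustration; they are not part of thingtalk.py.

```python
from typing import List, Sequence


def transition_reference(
    transition_matrix: Sequence[Sequence[int]],
    curr_state: Sequence[int],
    next_symbols: Sequence[int],
) -> List[int]:
    """Pure-Python mirror of the batched grammar transition (hypothetical helper).

    Assumes transition_matrix[s][a] is the state reached from state s on symbol a,
    and that curr_state[i] and next_symbols[i] belong to batch element i.
    """
    if len(curr_state) != len(next_symbols):
        raise ValueError("curr_state and next_symbols must have the same batch size")
    # Step 1 (like tf.gather): one transition row per batch element.
    transitions = [list(transition_matrix[s]) for s in curr_state]
    # Step 2 (like tf.gather_nd with [i, next_symbols[i]] index pairs):
    # pick the column selected by each element's next symbol.
    indices = list(zip(range(len(curr_state)), next_symbols))
    return [transitions[i][a] for i, a in indices]


# Hypothetical 3-state, 2-symbol grammar (not taken from thingtalk.py).
toy_matrix = [
    [1, 0],  # state 0: symbol 0 -> 1, symbol 1 -> 0
    [2, 0],  # state 1: symbol 0 -> 2, symbol 1 -> 0
    [2, 2],  # state 2: every symbol stays in 2
]
print(transition_reference(toy_matrix, [0, 1, 2], [0, 1, 0]))  # prints [1, 0, 2]
```

The gather-then-gather_nd pair amounts to reading `transition_matrix[curr_state[i], next_symbols[i]]` for each batch element `i`; in NumPy the same lookup is the single advanced-indexing expression `matrix[curr_state, next_symbols]`.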