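import tensorflow as tf

def constrain_logits(self, logits, curr_state):
    with tf.name_scope('constrain_logits'):
        # For each batch element, look up the boolean row of tokens that are
        # allowed in its current state (shape: [batch, output_size]).
        allowed_tokens = tf.gather(tf.constant(self.allowed_token_matrix), curr_state)
        assert allowed_tokens.get_shape()[1:] == (self.output_size,)
        # Keep the logits of allowed tokens and replace the others with a large
        # negative value, so a softmax gives them (near-)zero probability.
        constrained_logits = tf.where(allowed_tokens, logits, tf.fill(tf.shape(allowed_tokens), -1e+10))
        return constrained_logits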
Source file: thingtalk.py (Python)
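The masking step in `constrain_logits` can be tried out without TensorFlow. The following is a minimal sketch that swaps in NumPy: `np.take(..., axis=0)` plays the role of `tf.gather` and `np.where` that of `tf.where`. The 3-state, 4-token boolean table `ALLOWED`, the constant `NEG_INF` and the `softmax` helper are hypothetical stand-ins introduced for illustration; they are not taken from thingtalk.py.

```python
import numpy as np

# Hypothetical table: row s marks which of the 4 output tokens are allowed
# when a sequence is in state s (stands in for self.allowed_token_matrix).
ALLOWED = np.array([
    [True,  True,  False, False],
    [False, True,  True,  False],
    [True,  False, False, True],
])

NEG_INF = -1e10  # same large negative fill value as the original method


def constrain_logits(logits, curr_state):
    """Mask logits of tokens not allowed in each sequence's current state.

    logits: float array of shape [batch, output_size]
    curr_state: int array of shape [batch]
    """
    # Equivalent of tf.gather(matrix, curr_state): one mask row per batch item.
    allowed = np.take(ALLOWED, curr_state, axis=0)
    assert allowed.shape[1:] == (logits.shape[1],)
    # Keep allowed logits, replace the rest with a large negative number.
    return np.where(allowed, logits, np.full(allowed.shape, NEG_INF))


def softmax(x):
    # Subtract the row max for numerical stability before exponentiating.
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


logits = np.array([[2.0, 1.0, 3.0, 0.5],
                   [0.1, 0.2, 0.3, 0.4]])
state = np.array([0, 2])

masked = constrain_logits(logits, state)
probs = softmax(masked)
print(masked)
print(probs.round(3))
```

In this sketch, token 2 in the first row has the highest raw logit but receives zero probability because state 0 disallows it. Using a large finite value rather than `-inf` has a practical side effect: if every token in a row were masked, a softmax over all `-inf` would produce NaN, whereas `-1e10` degrades to a uniform distribution. Whether that was the author's reason for the choice is not stated in the source.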