def constrain_value_logits(self, logits, curr_state):
    """Suppress value-position logits that are not allowed in the current state.

    Builds a reduced allowed-token matrix from ``self.allowed_token_matrix``.
    It keeps the first ``num_control_tokens`` columns and every column from
    ``first_value_token`` onward, dropping the columns in between. The rows
    selected by ``curr_state`` are gathered from it. Every logit whose mask
    entry is False then has ``1e10`` subtracted, so it is effectively
    excluded by a later softmax/argmax. Allowed logits are left unchanged.

    Args:
        logits: float32 logits over the reduced vocabulary, i.e. with
            ``num_control_tokens + num_value_tokens`` entries in the last
            dimension. NOTE(review): presumably shaped (batch, vocab) to
            broadcast against the gathered mask -- confirm against callers.
        curr_state: index (or integer tensor of indices) into the rows of
            ``self.allowed_token_matrix``.

    Returns:
        ``logits`` with ``1e10`` subtracted at every disallowed position.

    Raises:
        ValueError: if the gathered mask does not have exactly
            ``num_control_tokens + num_value_tokens`` columns after its
            leading dimension (for example when ``curr_state`` is a scalar
            or ``allowed_token_matrix`` has an unexpected width).
    """
    # NOTE(review): the slicing below assumes the token columns are laid out
    # as [control | begin | function | value]; confirm against the vocabulary.
    first_value_token = self.num_functions + self.num_begin_tokens + self.num_control_tokens
    num_value_tokens = self.output_size - first_value_token
    num_allowed_columns = self.num_control_tokens + num_value_tokens

    # Keep only the control-token and value-token columns.
    value_allowed_token_matrix = np.concatenate(
        (self.allowed_token_matrix[:, :self.num_control_tokens],
         self.allowed_token_matrix[:, first_value_token:]),
        axis=1)

    with tf.name_scope('constrain_logits'):
        allowed_tokens = tf.gather(tf.constant(value_allowed_token_matrix), curr_state)

        # Static shape check. This is an explicit raise instead of `assert`
        # so it is not stripped under `python -O`. Deliberately written as
        # `not (... == ...)`: TF1's TensorShape.__ne__ raises on unknown
        # rank, whereas __eq__ simply returns False.
        mask_shape = allowed_tokens.get_shape()
        if not (mask_shape[1:] == (num_allowed_columns,)):
            raise ValueError(
                'constrain_value_logits: expected allowed-token mask of shape '
                '[?, {}], got {}'.format(num_allowed_columns, mask_shape))

        # 1.0 where a token is disallowed, 0.0 where it is allowed.
        # tf.cast replaces the deprecated (removed in TF 2) tf.to_float.
        disallowed = tf.cast(tf.logical_not(allowed_tokens), tf.float32)
        constrained_logits = logits - disallowed * 1e+10
        return constrained_logits
thingtalk.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录