def get_energies(self, y: tf.Tensor, weights_in_time: tf.TensorArray):
    """Compute attention logits that include a coverage term.

    The entries of ``weights_in_time`` are stacked and summed over axis 0
    (or replaced by the scalar ``0.0`` when the array is empty). That sum
    is divided by ``self.fertility`` and multiplied by
    ``self.attention_mask`` to give the coverage, which is expanded by two
    trailing axes, scaled by ``self.coverage_weights`` and added to
    ``self.hidden_features + y`` before the ``tanh``.

    Args:
        y: Tensor added to ``self.hidden_features``; it must broadcast
            against it (presumably the projected decoder query -- TODO
            confirm against the caller).
        weights_in_time: TensorArray whose stacked entries are summed
            (presumably the attention weights of earlier steps -- TODO
            confirm).

    Returns:
        ``self.similarity_bias_vector * tanh(...)`` summed over axes 2
        and 3.
    """
    def _stacked_sum():
        # Total of everything written to the array so far.
        return tf.reduce_sum(weights_in_time.stack(), axis=0)

    def _zero():
        # Nothing has been written yet, so stacking is skipped entirely.
        return 0.0

    has_history = tf.greater(weights_in_time.size(), 0)
    weight_total = tf.cond(has_history, _stacked_sum, _zero)

    coverage = weight_total / self.fertility * self.attention_mask

    # Attributes are read in the same order as in the single-expression
    # form, in case any of them is built lazily on first access.
    bias = self.similarity_bias_vector
    query_term = self.hidden_features + y
    coverage_scale = self.coverage_weights

    # Two trailing singleton axes so the coverage broadcasts against the
    # feature tensor.
    coverage_expanded = tf.expand_dims(coverage, -1)
    coverage_expanded = tf.expand_dims(coverage_expanded, -1)

    activation = tf.tanh(query_term + coverage_scale * coverage_expanded)
    return tf.reduce_sum(bias * activation, axis=[2, 3])
# Comment list
# Article table of contents