def span_score_logits(spans, spans_mask):
w_a = tf.Variable(tf.random_normal([n_hidden]))
h_a = FFNN(spans, spans_mask, 'spans')
s_a = tf.tensordot(h_a, w_a, axes=[[-1],[-1]])
return s_a * spans_mask[:, :, 0]
评论列表
文章目录