def apply(self, is_train, context_embed, answer, context_mask=None):
    """Predict answer-span bounds over the context.

    Runs ``self.predictor`` to obtain two encodings of the context, projects
    each one to a single unnormalized score per position (start logits and
    end logits), then delegates to ``self.span_predictor.predict``.

    Args:
        is_train: Forwarded only to ``self.predictor.apply``; presumably a
            training-mode flag (e.g. for dropout) -- TODO confirm.
        context_embed: Embedded context, forwarded to ``self.predictor.apply``.
        answer: Forwarded unchanged to ``self.span_predictor.predict``.
        context_mask: Optional; forwarded unchanged to both the predictor and
            the span predictor. Defaults to ``None``.

    Returns:
        Whatever ``self.span_predictor.predict`` returns.
    """
    init_fn = get_keras_initialization(self.init)

    # Scope names determine the created variable names; keep them unchanged
    # so existing checkpoints remain loadable.
    with tf.variable_scope("bounds_encoding"):
        # m1 feeds the start scorer and m2 feeds the end scorer below.
        m1, m2 = self.predictor.apply(is_train, context_embed, context_mask)

    # NOTE(review): m1/m2 are assumed to be rank 3, presumably
    # (batch, context_len, dim); the squeeze on axis 2 relies on this -- TODO confirm.
    with tf.variable_scope("start_pred"):
        # Linear projection (no activation) to one score per position.
        logits1 = fully_connected(m1, 1, activation_fn=None,
                                  weights_initializer=init_fn)
        # `axis` replaces the deprecated `squeeze_dims` keyword.
        logits1 = tf.squeeze(logits1, axis=[2])

    with tf.variable_scope("end_pred"):
        logits2 = fully_connected(m2, 1, activation_fn=None,
                                  weights_initializer=init_fn)
        logits2 = tf.squeeze(logits2, axis=[2])

    with tf.variable_scope("predict_span"):
        return self.span_predictor.predict(answer, logits1, logits2, context_mask)
评论列表
文章目录