def traditional_transition_loss_pred(self, i, j, combined_head, combined_dep):
    """Score relation labels for the transition state addressed by (i, j).

    The first four transition feature ids select rows from ``combined_head``
    and ``combined_dep``. Rows whose id is negative are zeroed out, the two
    groups are merged with ``self.rel_merge`` and the flattened result is
    projected to label logits with ``self.rel_dense``.

    Args:
        i: First index of the state.
            NOTE(review): presumably the sentence/batch index -- confirm.
        j: Second index of the state. When ``self.train`` is falsy the pair
            is flattened to ``i * self.args.beam_size + j``.
            NOTE(review): presumably the beam slot (prediction) or the
            transition step (training) -- confirm against the caller.
        combined_head: Tensor gathered along axis 0 by feature id; the four
            gathered rows are reshaped to ``[4, self.args.rel_emb_dim]``.
        combined_dep: Tensor gathered and reshaped the same way as
            ``combined_head``.

    Returns:
        When ``self.train`` is truthy: ``logsumexp(logits) - logits[gold]``
        where ``gold = self.trans_labels[i, j]``, i.e. the negative
        log-softmax of the gold label. Otherwise: the vector
        ``logsumexp(logits) - logits``, i.e. the negative log-softmax of
        every label.
    """
    if self.train:
        feat_ids = self.trans_feat_ids[i, j]
    else:
        # At prediction time the feature ids are addressed through a single
        # flattened index instead of an (i, j) pair.
        feat_ids = self.trans_feat_ids[i * self.args.beam_size + j]
    rel_ids = feat_ids[:4]

    # Negative ids are treated as "no feature"; their rows are zeroed below.
    valid = tf.cast(
        tf.reshape(tf.greater_equal(rel_ids, 0), [4, 1]), tf.float32)

    # tf.gather raises on out-of-range indices on CPU (and silently yields
    # zeros on GPU), so clamp negative ids to a valid row before gathering.
    # The mask discards whatever was gathered for those positions.
    safe_ids = tf.maximum(rel_ids, 0)

    emb_shape = [4, self.args.rel_emb_dim]
    rel_head = tf.reshape(tf.gather(combined_head, safe_ids), emb_shape)
    rel_dep = tf.reshape(tf.gather(combined_dep, safe_ids), emb_shape)
    rel_head = tf.multiply(valid, rel_head)
    rel_dep = tf.multiply(valid, rel_dep)

    rel_hid = self.rel_merge(rel_head, rel_dep)
    # Flatten the merged features into a single row for the dense layer,
    # then flatten its output into a 1-D logit vector.
    rel_logit = self.rel_dense(tf.reshape(rel_hid, [1, -1]))
    rel_logit = tf.reshape(rel_logit, [-1])
    log_partition = tf.reduce_logsumexp(rel_logit)

    if self.train:
        return log_partition - rel_logit[self.trans_labels[i, j]]
    return log_partition - rel_logit
评论列表
文章目录