def ASw_transition_loss_pred(self, i, j, combined_head, combined_dep, transition_logit, SHIFT):
    """Score the candidate transitions for beam item (i, j).

    In training mode, returns a scalar negative log-likelihood of the gold
    transition under a softmax over all candidate transitions (arc
    transitions plus, optionally, SHIFT).  In inference mode, returns a flat
    tensor of negated log-probabilities for every candidate transition,
    padded with 1e20 up to ``self.pred_output_size``.

    Args:
        i, j: indices of the current sentence / beam slot.  NOTE(review):
            the indexing convention flips between modes — eval uses the
            flattened index ``i * beam_size + j``, train uses ``(i, j)``.
        combined_head: embedding table gathered by head ids.
            # assumes last dim equals self.args.rel_emb_dim — see reshape below
        combined_dep: embedding table gathered by dependent ids (same shape
            assumption as ``combined_head``).
        transition_logit: 1-D tensor of logits for non-arc transitions,
            indexed by the id stored at ``rel_trans_feat_ids[0, 3]``.
        SHIFT: scalar transition id marking a SHIFT; a SHIFT row carries no
            head/dep pair and is scored from ``transition_logit`` instead.

    Returns:
        Training: scalar loss tensor.  Inference: 1-D tensor of padded,
        negated log-probabilities.
    """
    # extract relevant portions of params
    rel_trans_feat_ids = self.trans_feat_ids[i*self.args.beam_size+j] if not self.train else self.trans_feat_ids[i, j]
    rel_trans_feat_size = self.trans_feat_sizes[i*self.args.beam_size+j] if not self.train else self.trans_feat_sizes[i, j]
    # core computations
    # has_shift == 1 iff the first candidate row encodes a SHIFT transition.
    has_shift = tf.cond(tf.equal(rel_trans_feat_ids[0, 0], SHIFT), lambda: tf.constant(1), lambda: tf.constant(0))
    # Remaining rows are arc transitions; skip the leading SHIFT row if present.
    arc_trans_count = rel_trans_feat_size - has_shift
    arc_trans_feat_ids = tf.gather(rel_trans_feat_ids, tf.range(has_shift, rel_trans_feat_size))
    # Columns 1 and 2 of each feature row hold head / dependent indices into
    # the combined embedding tensors.  # presumably — verify against caller
    rel_head = tf.reshape(tf.gather(combined_head, arc_trans_feat_ids[:, 1]), [arc_trans_count, self.args.rel_emb_dim])
    rel_dep = tf.reshape(tf.gather(combined_dep, arc_trans_feat_ids[:, 2]), [arc_trans_count, self.args.rel_emb_dim])
    rel_hid = self.rel_merge(rel_head, rel_dep)
    # rel_logit: one row of label logits per arc transition; flattened so the
    # softmax ranges over every (transition, label) pair jointly.
    rel_logit = self.rel_dense(rel_hid)
    arc_logit = tf.reshape(rel_logit, [-1])

    def logaddexp(a, b):
        # Numerically stable log(exp(a) + exp(b)); tf.log is the TF1 name.
        mx = tf.maximum(a, b)
        return tf.log(tf.exp(a-mx) + tf.exp(b-mx)) + mx

    if self.train:
        # compute a loss and return it
        log_partition = tf.reduce_logsumexp(arc_logit)
        # If SHIFT is available, fold its logit (looked up from the id at
        # column 3 of the SHIFT row) into the partition function.
        log_partition = tf.cond(tf.greater(has_shift, 0),
            lambda: logaddexp(log_partition, transition_logit[rel_trans_feat_ids[0, 3]]),
            lambda: log_partition)
        # arc_logit now holds per-candidate negative log-probabilities.
        arc_logit = log_partition - arc_logit
        # Gold-label convention (inferred from the indexing below — TODO
        # confirm): when SHIFT is present, label 0 means "take SHIFT" and
        # arc labels are offset by one; without SHIFT the label indexes
        # the flattened arc logits directly.
        res = tf.cond(tf.greater(has_shift, 0),
            lambda: tf.cond(tf.greater(self.trans_labels[i, j], 0),
                lambda: arc_logit[self.trans_labels[i, j]-1],
                lambda: log_partition - transition_logit[rel_trans_feat_ids[0, 3]]),
            lambda: arc_logit[self.trans_labels[i, j]])
        return res
    else:
        # just return predictions
        arc_logit = tf.reshape(rel_logit, [-1])
        log_partition = tf.reduce_logsumexp(arc_logit)
        log_partition = tf.cond(tf.greater(has_shift, 0),
            lambda: logaddexp(log_partition, transition_logit[rel_trans_feat_ids[0, 3]]),
            lambda: log_partition)
        arc_logit = log_partition - arc_logit
        # Prepend the SHIFT score (when present) so output order matches the
        # label convention used in the training branch above.
        arc_pred = tf.cond(tf.greater(has_shift, 0),
            lambda: tf.concat([tf.reshape(log_partition - transition_logit[rel_trans_feat_ids[0, 3]], (-1,1)),
                tf.reshape(arc_logit, (-1,1))], 0),
            lambda: tf.reshape(arc_logit, (-1, 1)))
        # correct shape
        # Pad with a huge value (worst possible negated log-prob) up to the
        # fixed prediction width so outputs are stackable across beam items.
        current_output_shape = has_shift + arc_trans_count * rel_logit.get_shape()[1]
        arc_pred = tf.concat([arc_pred, 1e20 * tf.ones((tf.subtract(self.pred_output_size, current_output_shape), 1), dtype=tf.float32)], 0)
        arc_pred = tf.reshape(arc_pred, [-1])
        return arc_pred
# NOTE(review): removed stray scraped-page artifacts ("评论列表" / "文章目录",
# i.e. blog "comment list" / "article contents" links) — not valid Python.