def pos_loss_pred(self, i, pos_embeddings, pos_logit, NUM_POS, gold_pos, pos_trainables):
    """Build the POS loss (training) or POS predictions (inference) for sentence ``i``.

    Behaviour depends on ``self.args.no_pos``:

    * ``no_pos`` set -- gold tags are used directly. The POS embedding is a
      plain ``tf.nn.embedding_lookup`` of ``gold_pos[i]``; there is no loss.
    * otherwise -- tags are predicted from ``pos_logit``. Row 0 of
      ``pos_logit`` is excluded; the remaining rows are softmax-normalised.
      The POS embedding is the embedding of id ``NUM_POS`` (one row) stacked
      on top of ``softmax(pos_logit[1:]) @ pos_trainables``.

    Args:
        i: Index of the sentence in ``gold_pos`` and ``self.sent_lengths``.
        pos_embeddings: Embedding table for ``tf.nn.embedding_lookup``.
        pos_logit: Per-position tag scores. Presumably shaped
            (positions, NUM_POS); the flattened indexing in the loss relies on
            each row having exactly NUM_POS entries (TODO confirm).
        NUM_POS: Number of tags; also the lookup id of the embedding row
            placed in front of the predicted (soft) embeddings.
        gold_pos: Gold tag ids, indexed first by sentence.
        pos_trainables: Matrix right-multiplied by the tag distribution.

    Returns:
        ``(first, pos_emb)`` where ``first`` is:
          - training, ``no_pos``: the constant ``0``;
          - training, predicted: the summed negative log-likelihood of gold
            tags ``gold_pos[i][1:self.sent_lengths[i]]``;
          - inference, ``no_pos``: ``gold_pos[i][1:self.sent_length]``;
          - inference, predicted: int32 argmax tag ids for rows ``1..`` of
            ``pos_logit``.
    """
    if self.args.no_pos:
        # Gold tags are consumed as-is: embed them directly, nothing to learn.
        gold_emb = tf.nn.embedding_lookup(pos_embeddings, gold_pos[i])
        if not self.train:
            # NOTE(review): this slices up to ``self.sent_length`` whereas the
            # loss below uses ``self.sent_lengths[i]``. It may be a padded
            # length (matching the full-length argmax in the predicted path)
            # or a typo -- confirm against the caller.
            return tf.gather(gold_pos[i], tf.range(1, self.sent_length)), gold_emb
        return 0, gold_emb

    # Row 0 is not predicted; its slot in the output embedding is filled by
    # the NUM_POS embedding below.
    scores = pos_logit[1:]
    log_z = tf.reduce_logsumexp(scores, [1])
    # Numerically stable softmax: exp(scores - logsumexp(scores)).
    tag_probs = tf.exp(scores - tf.reshape(log_z, (-1, 1)))

    lead_row = tf.reshape(tf.nn.embedding_lookup(pos_embeddings, NUM_POS), (1, -1))
    soft_rows = tf.matmul(tag_probs, pos_trainables)
    soft_emb = tf.concat([lead_row, soft_rows], 0)

    if not self.train:
        return tf.cast(tf.argmax(tag_probs, 1), tf.int32), soft_emb

    # Row t of ``scores`` is row t+1 of ``pos_logit``, so its target is
    # gold_pos[i][t+1]. Pick scores[t, target] via a flat index into the
    # row-major flattened score matrix.
    steps = tf.range(self.sent_lengths[i] - 1)
    targets = tf.gather(gold_pos[i], tf.range(1, self.sent_lengths[i]))
    gold_scores = tf.gather(tf.reshape(scores, [-1]), steps * NUM_POS + targets)
    # Sum over positions of -log softmax(scores)[t, target].
    nll = tf.reduce_sum(tf.gather(log_z, steps) - gold_scores)
    return nll, soft_emb
评论列表
文章目录