import tensorflow as tf


def sequence_loss_by_mle(logits, targets, vocab_size, sequence_length, batch_size, output_projection=None):
    # logits:  list of [batch_size, emb_dim] tensors, one per time step
    #          (already [batch_size, vocab_size] when no output_projection is given)
    # targets: [seq_len, batch_size] token ids
    # Flatten the targets to [seq_len * batch_size] so they line up with the reshaped logits.
    labels = tf.to_int32(tf.reshape(targets, [-1]))

    if output_projection is not None:
        # Project each per-step output from emb_dim to vocab_size: logit * W + b.
        logits = [tf.matmul(logit, output_projection[0]) + output_projection[1] for logit in logits]

    # Stack the per-step tensors and flatten to [seq_len * batch_size, vocab_size].
    reshape_logits = tf.reshape(logits, [-1, vocab_size])

    # Clip to avoid log(0); this assumes the values are already normalized probabilities
    # (apply tf.nn.softmax to reshape_logits first if they are raw scores).
    prediction = tf.clip_by_value(reshape_logits, 1e-20, 1.0)

    # Cross-entropy against the one-hot targets, averaged over all time steps and batch entries.
    pretrain_loss = -tf.reduce_sum(
        tf.one_hot(labels, vocab_size, 1.0, 0.0) * tf.log(prediction)
    ) / (sequence_length * batch_size)
    return pretrain_loss
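
A minimal sketch of how this loss might be wired into an MLE pretraining step, assuming TF1-style placeholders; the names g_outputs, W_out, b_out and the shapes (seq_len=20, batch_sz=64, emb_dim=32, vocab=5000) are hypothetical and only for illustration:

# Hypothetical shapes and variable names, for illustration only.
seq_len, batch_sz, emb_dim, vocab = 20, 64, 32, 5000

targets = tf.placeholder(tf.int32, [seq_len, batch_sz])        # gold token ids
g_outputs = [tf.placeholder(tf.float32, [batch_sz, emb_dim])   # one generator output per step
             for _ in range(seq_len)]
W_out = tf.Variable(tf.random_normal([emb_dim, vocab]))        # output projection weight
b_out = tf.Variable(tf.zeros([vocab]))                         # output projection bias

loss = sequence_loss_by_mle(g_outputs, targets, vocab, seq_len, batch_sz,
                            output_projection=(W_out, b_out))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

With output_projection=None the per-step tensors would instead have to be [batch_sz, vocab] already.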