seq2seq.py source code

python

Project: seqGan_chatbot  Author: zpppy
import tensorflow as tf  # TensorFlow 1.x API (tf.to_int32, tf.log)

def sequence_loss_by_mle(logits, targets, vocab_size, sequence_length, batch_size, output_projection=None):
    # logits: seq_len tensors of shape [batch_size, emb_dim]
    #         (already [batch_size, vocab_size] when output_projection is None)
    # targets: [seq_len, batch_size], flattened to [seq_len * batch_size]
    labels = tf.to_int32(tf.reshape(targets, [-1]))

    if output_projection is not None:
        # Project each time step from emb_dim to vocab_size: logit @ W + b
        logits = [tf.matmul(logit, output_projection[0]) + output_projection[1] for logit in logits]

    reshape_logits = tf.reshape(logits, [-1, vocab_size]) #[seq_len * batch_size, vocab_size]

    # Convert logits to probabilities, then clip to avoid log(0)
    prediction = tf.clip_by_value(tf.nn.softmax(reshape_logits), 1e-20, 1.0)

    pretrain_loss = -tf.reduce_sum(
        # [seq_len * batch_size , vocab_size]
        tf.one_hot(labels, vocab_size, 1.0, 0.0) * tf.log(prediction)
    ) / (sequence_length * batch_size)
    return pretrain_loss
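
The loss above is an average negative log-likelihood of the target tokens, taken over all seq_len * batch_size positions. The following is a minimal pure-Python sketch of that quantity, useful for checking small cases by hand. It assumes the inputs are raw, unnormalized scores that go through a softmax before the log; the names softmax and mle_loss_reference are hypothetical helpers for illustration and are not part of the seqGan_chatbot project.

```python
import math

def softmax(scores):
    """Turn one row of raw scores into probabilities (numerically stable)."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def mle_loss_reference(logits, targets):
    """Average negative log-likelihood of the target tokens.

    logits:  nested list, shape [seq_len][batch_size][vocab_size] (assumed raw scores)
    targets: nested list, shape [seq_len][batch_size], integer token ids
    """
    seq_len = len(logits)
    batch_size = len(logits[0])
    total = 0.0
    for t in range(seq_len):
        for b in range(batch_size):
            probs = softmax(logits[t][b])
            # Clip to [1e-20, 1.0], as the TensorFlow code does, to avoid log(0).
            p = min(max(probs[targets[t][b]], 1e-20), 1.0)
            total -= math.log(p)
    return total / (seq_len * batch_size)

# Uniform scores over a 4-token vocabulary: every token has probability 1/4,
# so the loss equals log(4) whatever the targets are.
uniform_logits = [[[0.0, 0.0, 0.0, 0.0] for _ in range(2)] for _ in range(3)]
uniform_targets = [[0, 1], [2, 3], [1, 0]]
loss = mle_loss_reference(uniform_logits, uniform_targets)
```

Picking the target probability by index is equivalent to multiplying the log-probabilities by a one-hot mask and summing, which is how the TensorFlow version expresses it.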