models.py source code

python

Project: seq2seq, Author: eske
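import tensorflow as tf  # TF 1.x API

# Assumption: `dense` is the TF 1.x dense layer; its import is not shown in this excerpt.
dense = tf.layers.dense


def reinforce_baseline(decoder_states, reward):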
    """
    Center the reward by computing a baseline reward over decoder states.

    :param decoder_states: internal states of the decoder, tensor of shape (batch_size, time_steps, state_size)
    :param reward: reward for each time step, tensor of shape (batch_size, time_steps)
    :return: reward - computed baseline, tensor of shape (batch_size, time_steps)
    """
    # batch_size = tf.shape(decoder_states)[0]
    # time_steps = tf.shape(decoder_states)[1]
    # state_size = decoder_states.get_shape()[2]
    # states = tf.reshape(decoder_states, shape=tf.stack([batch_size * time_steps, state_size]))

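    # Linear baseline predicted from the decoder states. stop_gradient keeps the
    # baseline from back-propagating into the decoder itself.
    baseline = dense(tf.stop_gradient(decoder_states), units=1, activation=None, name='reward_baseline',
                     kernel_initializer=tf.constant_initializer(0.01))
    # (batch_size, time_steps, 1) -> (batch_size, time_steps)
    baseline = tf.squeeze(baseline, axis=2)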

    # baseline = tf.reshape(baseline, shape=tf.stack([batch_size, time_steps]))
    return reward - baseline
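
The arithmetic performed by reinforce_baseline can be illustrated without TensorFlow. The following is a minimal sketch, not the project's code: it uses plain Python lists, and the names linear_baseline, center_reward, weights and bias are hypothetical. It assumes the dense layer reduces to a dot product with one weight per state dimension plus a scalar bias, and it leaves out gradient handling, so there is no counterpart to tf.stop_gradient.

```python
def linear_baseline(decoder_states, weights, bias=0.0):
    """Scalar baseline for every (batch, time) position: dot(state, weights) + bias."""
    return [
        [sum(s * w for s, w in zip(state, weights)) + bias for state in sequence]
        for sequence in decoder_states
    ]


def center_reward(decoder_states, reward, weights, bias=0.0):
    """Subtract the baseline from the reward, element by element."""
    baseline = linear_baseline(decoder_states, weights, bias)
    return [
        [r - b for r, b in zip(reward_row, baseline_row)]
        for reward_row, baseline_row in zip(reward, baseline)
    ]


# batch_size=1, time_steps=2, state_size=3
states = [[[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]]
reward = [[1.0, 0.5]]
weights = [0.01, 0.01, 0.01]  # mirrors tf.constant_initializer(0.01) before any training

centered = center_reward(states, reward, weights)
```

The result has the same (batch_size, time_steps) shape as reward, which matches the shape promised in the docstring above. In the original function the weights are trainable variables, so the baseline changes during training instead of staying at these initial values.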