def reinforce_baseline(decoder_states, reward):
    """
    Subtract a learned, state-dependent baseline from the per-step reward.

    A single-unit layer named ``reward_baseline`` (no activation, kernel
    initialised to the constant 0.01) maps every decoder state to one scalar
    baseline estimate. The states are wrapped in ``tf.stop_gradient`` before
    entering that layer, so the baseline term does not propagate gradients
    into the decoder states.

    :param decoder_states: internal states of the decoder, tensor of shape (batch_size, time_steps, state_size)
    :param reward: reward for each time step, tensor of shape (batch_size, time_steps)
    :return: ``reward`` minus the computed baseline, tensor of shape (batch_size, time_steps)
    """
    # Block gradients first: only the baseline layer's own weights are meant to
    # be trained through this path.
    frozen_states = tf.stop_gradient(decoder_states)

    # The layer is applied to the 3-D tensor directly; no explicit flattening to
    # (batch_size * time_steps, state_size) and back is performed.
    # NOTE(review): `dense` comes from elsewhere in the file -- presumably
    # tf.layers.dense; confirm.
    per_step_estimate = dense(
        frozen_states,
        units=1,
        activation=None,
        name='reward_baseline',
        kernel_initializer=tf.constant_initializer(0.01),
    )

    # units=1 leaves a trailing axis of size 1; drop it so the estimate lines up
    # with `reward`.
    return reward - tf.squeeze(per_step_estimate, axis=2)
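# Web-page residue from the scraped source (not part of the module):
# "Comment list" / "Article table of contents"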