def __call__(self, inputs, state, scope=None):
    """Run one step of the variational recurrent cell (VRNN).

    The recurrence is carried entirely by the latent variable. The previous
    sample ``z``, its mean and its log-variance are concatenated with
    ``inputs`` and passed through one affine layer plus ``self._activation``.
    The result is split in half to give the new mean and log-variance. A new
    latent sample is then drawn with the reparameterization
    ``z = mean + exp(0.5 * log_var) * eps``, where ``eps ~ N(0, 1)``.

    Args:
      inputs: Input tensor for this step. Presumably 2-D (batch, input_dim);
        only ``tf.shape(inputs)[0]`` (the batch size) is read directly here.
      state: 3-tuple ``(z, z_mean, z_log_sigma_sq)`` from the previous step,
        i.e. the ``VRNNStateTuple`` returned by the previous call.
      scope: Optional variable-scope name; defaults to the class name.

    Returns:
      A pair ``(z_t_1, new_state)``. ``z_t_1`` is the freshly sampled latent
      (also used as the cell output). ``new_state`` is
      ``VRNNStateTuple(z_t_1, z_mean_t_1, z_log_sigma_sq_t_1)``.
    """
    with tf.variable_scope(scope or type(self).__name__):
        # Unpack the previous latent sample and its distribution parameters.
        z_t, z_mean_t, z_log_sigma_sq_t = state

        # One layer over [inputs, z, mean, log_var] producing 2 * num_units
        # features: the first half becomes the new mean, the second half the
        # new log-variance. (`_linear` is presumably the rnn_cell helper whose
        # third argument enables a bias term -- confirm.)
        # NOTE(review): self._activation is applied to BOTH halves, so the
        # log-variance is squashed as well (bounded if the activation is
        # tanh); confirm this is intended.
        h_t_1 = self._activation(_linear(
            [inputs, z_t, z_mean_t, z_log_sigma_sq_t],
            2 * self._num_units,
            True))
        # Pre-1.0 TensorFlow argument order: split(axis, num_splits, value).
        z_mean_t_1, z_log_sigma_sq_t_1 = tf.split(1, 2, h_t_1)

        # Reparameterized sample. The noise takes the mean's dtype rather
        # than a hard-coded float32, so non-float32 cells do not hit a dtype
        # mismatch when eps is combined with the mean and std below.
        eps = tf.random_normal((tf.shape(inputs)[0], self._num_units),
                               0.0, 1.0, dtype=z_mean_t_1.dtype)
        # exp(0.5 * log_var) == sqrt(exp(log_var)), but it never materializes
        # exp(log_var), which overflows at half the log-variance value
        # (about log_var > 88 in float32).
        z_std_t_1 = tf.exp(0.5 * z_log_sigma_sq_t_1)
        z_t_1 = z_mean_t_1 + z_std_t_1 * eps
        return z_t_1, VRNNStateTuple(z_t_1, z_mean_t_1, z_log_sigma_sq_t_1)