def alpha(self, inputs, state=None, u=None, buffer=None, reuse=None, init_buffer=False, name='alpha'):
    """The dynamics parameter network alpha for mixing transitions in a state space model.

    This function is quite general and supports different architectures
    (NN, RNN, FIFO queue, learning the inputs).

    Args:
        inputs: tensor to condition the mixing vector on
        state: previous state if using an RNN network to model alpha
        u: pass-through variable if u is given (learn_u=False)
        buffer: buffer for the FIFO network (used for fifo_size > 1)
        reuse: `True` or `None`; if `True`, we go into reuse mode for this scope as
            well as all sub-scopes; if `None`, we just inherit the parent scope reuse.
        init_buffer: initialize the buffer for a_t
        name: name of the scope

    Returns:
        alpha: mixing vector of dimension (batch size, K)
        state: new state
        u: either the u inferred by the model or the pass-through value
        buffer: FIFO buffer
    """
    # Increase the number of hidden units if we also learn u (learn_u=True)
    num_units = self.config.alpha_units * 2 if self.config.learn_u else self.config.alpha_units

    # Overwrite the input buffer with zeros when asked to (re)initialize it
    if init_buffer:
        buffer = tf.zeros((tf.shape(inputs)[0], self.config.dim_a, self.config.fifo_size), dtype=tf.float32)

    # If K == 1 there is nothing to mix: alpha is simply a vector of ones
    if self.config.K == 1:
        return tf.ones([self.config.batch_size, self.config.K]), state, u, buffer
    with tf.variable_scope(name, reuse=reuse):
        if self.config.alpha_rnn:
            rnn_cell = BasicLSTMCell(num_units, reuse=reuse)
            output, state = rnn_cell(inputs, state)
        else:
            # Shift the FIFO buffer: drop the oldest entry and append the newest input
            buffer = tf.concat([buffer[:, :, 1:], tf.expand_dims(inputs, 2)], 2)
            # Flatten the buffer and pass it through alpha_layers fully connected layers
            output = slim.repeat(
                tf.reshape(buffer, (tf.shape(inputs)[0], self.config.dim_a * self.config.fifo_size)),
                self.config.alpha_layers, slim.fully_connected, num_units,
                get_activation_fn(self.config.alpha_activation), scope='hidden')

        # Get alpha from the first alpha_units entries of the output (softmax over K)
        alpha = slim.fully_connected(output[:, :self.config.alpha_units],
                                     self.config.K,
                                     activation_fn=tf.nn.softmax,
                                     scope='alpha_var')

        if self.config.learn_u:
            # If u is learned, infer it from the remaining entries of the output
            u = slim.fully_connected(output[:, self.config.alpha_units:],
                                     self.config.dim_u, activation_fn=None, scope='u_var')

    return alpha, state, u, buffer
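
For reference, the (batch_size, K) vector alpha returned above is typically used to interpolate between K globally learned state-space matrices at each time step, so that the effective transition matrix becomes a convex combination A_t = sum_k alpha_t[k] * A_k. The sketch below is a minimal, self-contained NumPy illustration of that mixing step under this assumption; it is not code from the repository, and the names A_bases and alpha_t are hypothetical stand-ins for the model's base matrices and the output of the alpha network.

import numpy as np

# Hypothetical sizes, for illustration only
batch_size, K, dim_z = 4, 3, 2

# K base transition matrices shared across the dataset: shape (K, dim_z, dim_z)
A_bases = np.stack([np.eye(dim_z) + 0.01 * np.random.randn(dim_z, dim_z)
                    for _ in range(K)])

# One mixing vector per batch element; rows sum to 1, as a softmax output would: (batch_size, K)
alpha_t = np.random.dirichlet(np.ones(K), size=batch_size)

# Per-sample mixed transition matrix: A_t[b] = sum_k alpha_t[b, k] * A_bases[k]
A_t = np.einsum('bk,kij->bij', alpha_t, A_bases)  # shape (batch_size, dim_z, dim_z)

assert A_t.shape == (batch_size, dim_z, dim_z)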