def get_output_for(self, input, **kwargs):
    """Run the recurrent layer over ``input`` and return one output per step.

    The initial state ``self.h0`` is reshaped to a single row and tiled once
    per batch entry. Two execution paths follow:

    * ``self.horizon`` is ``None``: trailing axes of ``input`` are flattened
      into one, the first two axes are swapped so the step axis comes first,
      and ``self.step`` is folded over the result with ``tf.scan``. The
      scan result has its first two axes swapped back before returning.
    * ``self.horizon`` is set: ``self.gru`` is called in a Python loop on
      ``input[:, t, :]`` for every ``t`` in ``range(self.horizon)``, carrying
      the state forward, and the per-step outputs are joined along axis 1.

    Args:
        input: Tensor whose axis 0 is read as the batch size and whose
            axis 1 is read as the number of steps. The unrolled path
            indexes it with three subscripts.
        **kwargs: Accepted but not used in this method.

    Returns:
        A tensor with the batch axis first and the step axis second.
        Presumably shaped (batch, steps, num_units); the last axis depends
        on ``self.gru`` / ``self.step``, which are not visible here.
    """
    # NOTE(review): `tf.pack` and `tf.concat(1, values)` below follow the
    # pre-1.0 TensorFlow API (`pack` was later renamed `stack`, and `concat`
    # took the dimension first). Confirm the TensorFlow version this file
    # targets before changing either call.
    dyn_shape = tf.shape(input)
    batch_size = dyn_shape[0]

    # One copy of the initial hidden state per batch row.
    h0_row = tf.reshape(self.h0, (1, self.num_units))
    state = tf.tile(h0_row, (batch_size, 1))
    # The tiled size is dynamic; pin the static width of the state.
    state.set_shape((None, self.num_units))

    if self.horizon is None:
        n_steps = dyn_shape[1]
        # Flatten any extra trailing dimensions into a single feature axis.
        flat_input = tf.reshape(input, tf.pack([batch_size, n_steps, -1]))
        # tf.scan iterates over axis 0, so put the step axis first.
        steps_first = tf.transpose(flat_input, (1, 0, 2))
        steps_first.set_shape((None, None, self.input_shape[-1]))
        hidden_states = tf.scan(
            self.step,
            elems=steps_first,
            initializer=state,
        )
        # Swap back so the batch axis leads again.
        return tf.transpose(hidden_states, (1, 0, 2))

    # Fixed horizon: unroll the recurrence explicitly in the graph.
    step_outputs = []
    for t in range(self.horizon):
        step_out, state = self.gru(input[:, t, :], state, scope=self.scope)
        step_outputs.append(tf.expand_dims(step_out, 1))
    return tf.concat(1, step_outputs)
评论列表
文章目录