def output(self):
    """Run the recurrence over every step and return all hidden states.

    ``self._input_B_T_Di`` is reshaped so that every dimension after the first
    two is flattened into a single feature axis. ``self.h0`` is reshaped to
    ``(1, self._hidden_units)`` and tiled along the first axis so every row
    starts from the same initial state. ``self.step`` is then applied with
    ``tf.scan`` along the input's second axis.

    Returns:
        The states produced by ``self.step`` at every step, stacked and
        transposed back so the first two axes follow the input's order.
        NOTE(review): presumably ``(batch, time, hidden_units)`` -- this assumes
        ``self.step`` returns a rank-2 ``(batch, hidden_units)`` state; confirm
        against ``step``.
    """
    # ``tf.pack`` was renamed ``tf.stack`` (and ``tf.pack`` removed) in
    # TensorFlow 1.0; use whichever this TensorFlow version provides.
    stack = getattr(tf, "stack", None) or tf.pack

    input_shape = tf.shape(self._input_B_T_Di)
    batch_size, num_steps = input_shape[0], input_shape[1]

    # Flatten any trailing feature dimensions into one axis: (B, T, D).
    inputs_B_T_D = tf.reshape(
        self._input_B_T_Di, stack([batch_size, num_steps, -1]))

    # Repeat the initial state once per row of the first input axis: (B, H).
    h0_B_H = tf.tile(
        tf.reshape(self.h0, (1, self._hidden_units)), stack([batch_size, 1]))

    # tf.scan iterates over axis 0, so move the step axis to the front ...
    inputs_T_B_D = tf.transpose(inputs_B_T_D, (1, 0, 2))
    hs_T_B_H = tf.scan(self.step, elems=inputs_T_B_D, initializer=h0_B_H)
    # ... and restore the caller's original axis order.
    return tf.transpose(hs_T_B_H, (1, 0, 2))
评论列表
文章目录