def get_output_for(self, input, **kwargs):
    """Apply ``self.step`` across the second axis of ``input`` and return all states.

    ``input`` is indexed as (axis 0, axis 1, ...); every axis after the
    first two is collapsed into a single trailing axis before scanning.
    The initial state is ``self.h0`` reshaped to one row of
    ``self.num_units`` values and repeated once per entry along axis 0.

    Returns the scan outputs with their first two axes swapped back, i.e.
    ordered like the input (axis 0 first, then axis 1).
    NOTE(review): the last output axis is presumably ``num_units`` wide,
    but that depends on what ``self.step`` returns -- confirm there.
    ``**kwargs`` is accepted and ignored.
    """
    batch_size, seq_len = input.shape[0], input.shape[1]

    # Collapse any extra trailing dimensions into one feature axis.
    flat_input = TT.reshape(input, (batch_size, seq_len, -1))

    # scan walks the leading axis of its sequences, so move axis 1 to front.
    step_major_input = flat_input.dimshuffle(1, 0, 2)

    # One copy of the initial state per batch entry: (batch_size, num_units).
    initial_row = TT.reshape(self.h0, (1, self.num_units))
    initial_states = TT.tile(initial_row, (batch_size, 1))

    state_seq, _ = theano.scan(
        fn=self.step,
        sequences=[step_major_input],
        outputs_info=initial_states,
    )

    # Swap the first two axes back so the result lines up with the input.
    return state_seq.dimshuffle(1, 0, 2)
评论列表
文章目录