def get_output(self, train=False):
    """Build this layer's symbolic output by scanning ``self._step`` over time.

    Args:
        train: Forwarded unchanged to ``self.get_input`` and
            ``self.get_padded_shuffled_mask``.

    Returns:
        If ``self.return_sequences`` is true, the output of every scan step
        with its first two axes swapped back (undoing the input dimshuffle);
        otherwise only the output of the final scan step (``[-1]``).
    """
    inputs = self.get_input(train)
    # Opaque helper; computed from the un-shuffled input with pad=1.
    # Presumably yields a time-major mask aligned with the scan sequences --
    # TODO confirm against get_padded_shuffled_mask.
    step_mask = self.get_padded_shuffled_mask(train, inputs, pad=1)

    # theano.scan iterates over the leading axis, so swap axes 0 and 1.
    # Assumes inputs are (samples, timesteps, input_dim) -- TODO confirm.
    seq = inputs.dimshuffle((1, 0, 2))

    # Input-side affine projections for all steps at once; presumably the
    # GRU update (z), reset (r) and candidate (h) terms.
    z_in = T.dot(seq, self.W_z) + self.b_z
    r_in = T.dot(seq, self.W_r) + self.b_r
    h_in = T.dot(seq, self.W_h) + self.b_h

    # All-zero initial state of shape (seq.shape[1], self.output_dim).
    # Axis 1 is marked non-broadcastable, presumably so the state's type
    # matches _step's output even when output_dim == 1 -- NOTE(review): confirm.
    h0 = T.unbroadcast(alloc_zeros_matrix(seq.shape[1], self.output_dim), 1)

    # The recurrent weights are passed as non_sequences so scan hands them to
    # _step on every iteration without slicing them.
    step_outputs, _updates = theano.scan(
        self._step,
        sequences=[z_in, r_in, h_in, step_mask],
        outputs_info=h0,
        non_sequences=[self.U_z, self.U_r, self.U_h],
        truncate_gradient=self.truncate_gradient,
    )

    if not self.return_sequences:
        return step_outputs[-1]
    # Swap axes 0 and 1 back so the sequence output follows the input's order.
    return step_outputs.dimshuffle((1, 0, 2))
评论列表
文章目录