def get_output_for(self, input, **kwargs):
    """Blend a transformed copy of ``input`` with ``input`` itself via a gate.

    Computes ``H * G + x * (1 - G)`` where ``x`` is the (possibly flattened)
    input, ``H = self.nonlinearity(x . W_h [+ b_h])`` and
    ``G = sigmoid(x . W_t [+ b_t])``. Each bias is only added when it is
    not ``None``.

    Parameters
    ----------
    input : symbolic tensor
        If it has more than two dimensions it is flattened to two
        dimensions before the dot products, and the result is reshaped
        back to ``input.shape`` afterwards.
    **kwargs
        Accepted but not used by this method.

    Returns
    -------
    symbolic tensor
        The gated combination, with the same shape as ``input``.

    Notes
    -----
    The elementwise blend presumably requires ``W_h`` and ``W_t`` to map
    onto as many units as ``x`` has features -- TODO confirm where the
    weights are created.
    """
    # Inputs with more than two dimensions are handled as a batch of
    # feature vectors; remember this so the shape can be restored later.
    needs_flatten = input.ndim > 2
    if needs_flatten:
        flat = input.flatten(2)
    else:
        flat = input

    # Candidate branch: affine map followed by the layer's nonlinearity.
    hidden = T.dot(flat, self.W_h)
    if self.b_h is not None:
        hidden = hidden + self.b_h.dimshuffle('x', 0)
    hidden = self.nonlinearity(hidden)

    # Gate branch: affine map squashed by a sigmoid.
    gate = T.dot(flat, self.W_t)
    if self.b_t is not None:
        gate = gate + self.b_t.dimshuffle('x', 0)
    gate = nonlinearities.sigmoid(gate)

    # Whatever the gate does not let through is taken from the raw input.
    pass_through = 1.0 - gate
    blended = hidden * gate + flat * pass_through

    if not needs_flatten:
        return blended
    # Restore the original input shape.
    return T.reshape(blended, input.shape)
评论列表
文章目录