def get_output_for(self, input, **kwargs):
    """Repeat ``self.param`` over every leading axis of ``input``.

    Only ``input.ndim`` and ``input.shape`` are read; the values of
    ``input`` are never used. The result has shape
    ``input.shape[:-1] + (self.num_units,)``: the parameter is reshaped to
    ``(1, ..., 1, num_units)`` and then tiled by ``input.shape[:-1]`` along
    the leading axes and by 1 along the last axis.

    Parameters
    ----------
    input : symbolic tensor
        Tensor whose leading-axis sizes set the output shape. Its last
        axis is replaced by ``self.num_units``.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    symbolic tensor
        ``self.param`` repeated to shape ``input.shape[:-1] + (num_units,)``.

    Notes
    -----
    The reshape requires ``self.param`` to hold exactly
    ``self.num_units`` elements. ``TT`` is presumably ``theano.tensor``,
    imported elsewhere in the file; confirm this.
    """
    rank = input.ndim
    # One singleton axis per leading axis of input, then the unit axis.
    target_shape = (1,) * (rank - 1) + (self.num_units,)
    param_nd = TT.reshape(self.param, target_shape)
    # Repeat counts: copy along each leading axis, leave the last axis alone.
    # The counts are symbolic, so tile needs ndim passed explicitly.
    repeats = TT.concatenate([input.shape[:-1], [1]])
    return TT.tile(param_nd, repeats, ndim=rank)
评论列表
文章目录