def get_output_for(self, all_obs_var, **kwargs):
    """Broadcast ``self.output_var`` across the leading axes of ``all_obs_var``.

    The value held by ``self.output_var`` is reshaped to carry ``ndim - 1``
    leading singleton axes. It is then tiled so those axes match
    ``all_obs_var.shape[:-1]``, while its own trailing axis is left unchanged.

    :param all_obs_var: symbolic tensor. Only its ``ndim`` and its symbolic
        ``shape[:-1]`` are used; its values are never read.
    :param kwargs: accepted for interface compatibility and ignored.
    :return: symbolic tensor of rank ``all_obs_var.ndim``. This presumably has
        shape ``all_obs_var.shape[:-1] + (d,)``, assuming ``self.output_var``
        is a 1-D shared variable of length ``d`` (TODO confirm; other ranks
        do not line up with ``ndim=ndim`` below).
    """
    ndim = all_obs_var.ndim
    # Only the static shape is needed. borrow=True avoids copying the whole
    # stored array just to read ``.shape``.
    value_shape = self.output_var.get_value(borrow=True).shape
    # Prepend ndim - 1 singleton axes so the constant lines up with the
    # leading axes of ``all_obs_var``.
    reshaped_cnt = TT.reshape(self.output_var, (1,) * (ndim - 1) + value_shape)
    # Repeat along every leading axis; the final 1 keeps the trailing axis as-is.
    tile_arg = TT.concatenate([all_obs_var.shape[:-1], [1]])
    # ``tile_arg`` is a symbolic vector, so TT.tile cannot infer the output
    # rank from it; the rank must be passed explicitly.
    return TT.tile(reshaped_cnt, tile_arg, ndim=ndim)
评论列表
文章目录