def compile(self, session=None):
    """
    Rebuild the variational parameters if the data size changed, then
    hand off to the parent class's ``compile``.

    The shapes of ``q_mu`` and ``q_sqrt`` are derived from the number of
    rows in ``self.X``, so both are re-created whenever that count no
    longer equals the cached ``self.num_data``.

    :param session: forwarded unchanged to the parent ``compile``.
    :return: whatever the parent ``compile`` returns.
    """
    n_rows = self.X.shape[0]
    size_unchanged = self.num_data == n_rows
    if not size_unchanged:
        self.num_data = n_rows
        n_latent = self.num_latent
        # Zero-initialised array of shape (n_rows, n_latent).
        self.q_mu = Parameter(np.zeros((n_rows, n_latent)))
        # Identity matrix broadcast along a trailing axis, giving an
        # array of shape (n_rows, n_rows, n_latent).
        identity = np.eye(n_rows)[:, :, None]
        self.q_sqrt = Parameter(identity * np.ones((1, 1, n_latent)))
    return super(VGP, self).compile(session=session)
# (web page residue, not part of the code: "Comment list" / "Table of contents")