def calculate_mis(self, p_y_given_x, theta, Xm):
    """Estimate the mutual information between hidden units and visible variables.

    A random subset of at most ``self.max_samples`` rows of ``Xm`` is drawn
    without replacement.  The visible variables are then processed in column
    batches: for each batch the log marginals returned by
    ``self.calculate_marginals_on_samples`` are weighted by ``p_y_given_x``,
    summed over samples and hidden states, and divided by the number of
    observed (unmasked) entries of each column in the subsample.

    Parameters
    ----------
    p_y_given_x : ndarray
        Indexed as (n_hidden, n_samples, dim_hidden).
    theta : ndarray
        Parameters whose first axis indexes the visible variables; the slice
        for each batch is passed on to ``calculate_marginals_on_samples``.
    Xm : ndarray or numpy.ma.MaskedArray
        Two-dimensional data, (n_samples, n_visible).  Masked entries count
        as unobserved.

    Returns
    -------
    ndarray
        Array of shape (self.n_hidden, self.n_visible) holding the MI
        estimates in nats.  Columns without a single observed value in the
        subsample are reported as 0 instead of NaN/inf.
    """
    mis = np.zeros((self.n_hidden, self.n_visible))
    n_samples, n_visible = Xm.shape

    # Random row subset (uses the global NumPy random state).
    sample = np.random.choice(np.arange(n_samples), min(self.max_samples, n_samples), replace=False)
    # Per-column count of unmasked entries within the subsample.
    n_observed = np.sum(np.logical_not(ma.getmaskarray(Xm[sample])), axis=0)

    # Rough size of the 4-D marginal array, used only to choose the batch width.
    # NOTE(review): the factor 64 looks like bits per float64 and the full
    # n_samples (not the subsample size) is used, so this is probably a
    # conservative overestimate -- confirm before tuning.  self.ram is
    # presumably a budget in the same unit (GB).
    memory_size = float(n_samples * n_visible * self.n_hidden * self.dim_hidden * 64) / 1000**3
    if memory_size > 0:
        batch_size = int(np.clip(int(self.ram * n_visible / memory_size), 1, n_visible))
    else:
        # Degenerate (empty) problem: nothing to budget for, and dividing by a
        # zero estimate would raise ZeroDivisionError.
        batch_size = max(n_visible, 1)

    # Loop-invariant: the fancy-indexed copy does not depend on the batch.
    p_sampled = p_y_given_x[:, sample, :]

    for start in range(0, n_visible, batch_size):
        cols = slice(start, start + batch_size)
        # log_marg_x: n_hidden, n_samples, n_visible (batch), dim_hidden
        log_marg_x = self.calculate_marginals_on_samples(theta[cols, ...], Xm[sample, cols])
        totals = np.einsum('ijl,ijkl->ik', p_sampled, log_marg_x)
        counts = n_observed[cols][np.newaxis, :]
        # Divide only where at least one value was observed; elsewhere keep 0
        # rather than producing NaN/inf from a zero denominator.
        mis[:, cols] = np.divide(totals, counts, out=np.zeros(totals.shape), where=counts > 0)
    return mis  # MI in nats
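# Comment list (leftover web-page navigation text, not part of the code)
# Table of contents (leftover web-page navigation text, not part of the code)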