def sample_x_y_gumbel(self, x, temperature=10, test=False):
    """Sample y for input ``x`` by first drawing an auxiliary latent ``a``.

    Steps, in order:
      1. Wrap ``x`` with ``self.to_variable``.
      2. Ask ``self.q_a_x`` for a mean and log-variance.
      3. Draw ``a`` with ``F.gaussian(mean, ln_var)``.
      4. Delegate to ``self.sample_ax_y_gumbel`` with ``a`` and the wrapped ``x``.

    Args:
        x: Input batch. It is converted by ``self.to_variable``; the accepted
            input types depend on that helper, which is not visible here.
        temperature: Forwarded unchanged as ``temperature=`` to
            ``self.sample_ax_y_gumbel``. Judging by the method name it is
            presumably a Gumbel-softmax temperature — TODO confirm.
        test: Forwarded as ``test=`` to both ``self.q_a_x`` and
            ``self.sample_ax_y_gumbel``. It presumably toggles inference-mode
            behaviour in those helpers — verify against their definitions.

    Returns:
        Whatever ``self.sample_ax_y_gumbel`` returns for ``(a, x)``.
    """
    x_var = self.to_variable(x)
    # Mean and log-variance produced by q_a_x for this input.
    a_mean, a_ln_var = self.q_a_x(x_var, test=test)
    # NOTE(review): assumes F is chainer.functions, whose gaussian() draws
    # mean + exp(ln_var / 2) * eps with eps ~ N(0, I) — confirm the import.
    a_sample = F.gaussian(a_mean, a_ln_var)
    return self.sample_ax_y_gumbel(
        a_sample,
        x_var,
        temperature=temperature,
        test=test,
    )
# 评论列表 (web-page residue: "Comment list")
# 文章目录 (web-page residue: "Table of contents")