def forward(self, x):
    """
    Compute Gaussian-mixture parameters (weights, means, stds) from an input batch.

    The input goes through a shared projection followed by ``tanh``. Three
    separate projection heads then produce the mixture weights, the component
    means and the component standard deviations.

    :param x: Input tensor, expected shape (num. samples, input dim.).
    :return: Tuple ``(weights, means, stds)``. Per the original design each is
        expected to be (num. samples, num. mixtures); the actual shapes are
        determined by the projection layers, which are defined elsewhere.
        ``weights`` is softmax-normalised over the last dimension, so it sums
        to 1 along that axis. ``stds`` is produced by ``exp`` and is therefore
        positive.
    """
    # Shared hidden representation. torch.tanh is the canonical form of F.tanh.
    hidden = torch.tanh(self.projection(x))
    # Normalise explicitly over the last (mixture) axis. Omitting ``dim`` is
    # deprecated (UserWarning), and for 3-D inputs PyTorch would pick dim 0,
    # normalising across samples instead of across mixtures.
    weights = F.softmax(self.weights_projection(hidden), dim=-1)
    means = self.mean_projection(hidden)
    # exp maps unconstrained outputs to positive standard deviations.
    stds = torch.exp(self.std_projection(hidden))
    return weights, means, stds
评论列表
文章目录