nonlinear.py source code

python

Project: keita   Author: iwasaki-kenta
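import torch
import torch.nn.functional as F


# Method of an nn.Module subclass that defines the linear layers
# self.projection, self.weights_projection, self.mean_projection
# and self.std_projection.
def forward(self, x):
    """
    Forward pass of a model for non-linear data that represents its output
    as a mixture of N Gaussian distributions. A tanh-activated linear
    projection of the input feeds three further linear projections, which
    produce each Gaussian's mixture weight, mean and standard deviation.

    :param x: tensor of shape (num. samples, input dim.)
    :return: mixture weights, means and standard deviations, each of shape
        (num. samples, num. mixtures)
    """
    # Shared hidden representation (torch.tanh replaces the deprecated F.tanh).
    x = torch.tanh(self.projection(x))

    # Softmax over the mixture dimension (dim=1): weights are positive and sum to 1.
    weights = F.softmax(self.weights_projection(x), dim=1)
    # Means are unconstrained, so the raw projection is used directly.
    means = self.mean_projection(x)
    # exp() keeps every standard deviation strictly positive.
    stds = torch.exp(self.std_projection(x))

    return weights, means, stds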
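A network like this is usually trained by minimising the negative log-likelihood of the target under the predicted mixture, -log sum_k w_k N(y; mean_k, std_k). The loss used in keita is not part of this file, so the following is only a minimal sketch under that assumption: a scalar target y and one row (one sample) of the `forward()` outputs. It uses Python's standard `math` module so every term is visible, and applies the log-sum-exp trick to avoid underflow when all component densities are tiny. The function name `mixture_nll` is hypothetical.

```python
import math
from typing import Sequence


def mixture_nll(y: float, weights: Sequence[float], means: Sequence[float],
                stds: Sequence[float]) -> float:
    """Hypothetical helper: negative log-likelihood of scalar y under a
    Gaussian mixture given by one row of (weights, means, stds).

    Assumes weights are strictly positive and sum to 1 (as softmax output
    does) and stds are strictly positive (as exp output does).
    """
    # log(w_k) + log N(y; mean_k, std_k) for every component k
    log_terms = [
        math.log(w) - math.log(s) - 0.5 * math.log(2.0 * math.pi)
        - 0.5 * ((y - m) / s) ** 2
        for w, m, s in zip(weights, means, stds)
    ]
    # Log-sum-exp: factor out the largest term before exponentiating.
    peak = max(log_terms)
    log_density = peak + math.log(sum(math.exp(t - peak) for t in log_terms))
    return -log_density


if __name__ == "__main__":
    # One standard-normal component evaluated at its mean: 0.5 * log(2*pi)
    print(mixture_nll(0.0, [1.0], [0.0], [1.0]))
```

For example, a single component with weight 1, mean 0 and standard deviation 1 evaluated at y = 0 gives 0.5 * log(2π) ≈ 0.919. In a PyTorch training loop the same formula would typically be written with tensor operations and `torch.logsumexp` over the mixture dimension, then averaged over the batch.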