toy.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:pytorch-geometric-gan 作者: lim0606 项目源码 文件源码
def exp6(num_data=1000):
    """Sample a toy 2-D dataset from a single axis-aligned Gaussian.

    Both coordinates are drawn independently from a normal distribution
    with mean -5 and standard deviation 7, using torch's global random
    number generator (seed it with ``torch.manual_seed`` for
    reproducibility).

    Args:
        num_data: Number of points to sample. Must be at least 2.

    Returns:
        A tuple ``(d, label, sumloglikelihood)`` where

        * ``d`` is a float32 tensor of shape ``(num_data, 2)`` holding the
          sampled points (x in column 0, y in column 1),
        * ``label`` is an int32 tensor of shape ``(num_data,)`` filled
          with zeros (there is only one mixture component),
        * ``sumloglikelihood`` is a callable that takes ``x`` and returns
          the sum over points of ``log(p(x) + 1e-10)`` under the
          generating density. ``x`` is handed directly to
          ``scipy.stats.multivariate_normal.pdf``, so its last axis must
          hold the 2 coordinates.

    Raises:
        ValueError: If ``num_data`` is smaller than 2.
    """
    if num_data < 2:
        # The message now matches the guard: exactly 2 is accepted.
        raise ValueError('num_data should be at least 2. (num_data = {})'.format(num_data))

    center = -5
    sigma_x = 7
    sigma_y = 7

    # Single component: every sample belongs to it.
    n1 = num_data

    # Draw x first, then y, so the stream of random numbers consumed from
    # torch's generator stays in the same order as before.
    d1x = torch.empty(n1, 1, dtype=torch.float32).normal_(center, sigma_x)
    d1y = torch.empty(n1, 1, dtype=torch.float32).normal_(center, sigma_y)
    d = torch.cat((d1x, d1y), 1)

    # All points carry label 0.
    label = torch.zeros(num_data, dtype=torch.int32)

    # Analytic density of the generating distribution (diagonal covariance).
    rv1 = multivariate_normal(
        [center, center],
        [[math.pow(sigma_x, 2), 0.0], [0.0, math.pow(sigma_y, 2)]],
    )

    def pdf(x):
        # Component weight n1 / num_data is 1.0 here; kept in mixture form.
        return (float(n1) / float(num_data)) * rv1.pdf(x)

    def sumloglikelihood(x):
        # The 1e-10 offset guards against log(0) for points far in the tails.
        return np.sum(np.log(pdf(x) + 1e-10))

    return d, label, sumloglikelihood
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号