utils.py source code

python

Project: restricted-boltzmann-machine-deep-belief-network-deep-boltzmann-machine-in-pytorch  Author: wmingwei
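import numpy as np
import torch
from torch.autograd import Variable


def generate(dbn, iteration=1, prop_input=None, annealed=False, n=0):
    # Draw n samples from a DBN: Gibbs-sample the top RBM, then propagate the
    # result down through the lower RBMs. n is the number of samples (the
    # default 0 produces none).
    top = dbn.rbm_layers[-1]

    if prop_input is not None:
        # Propagate the given input up to the top RBM's visible layer and use
        # its mean activation as the firing probability of the initial hidden units.
        prop_v = Variable(torch.from_numpy(prop_input).type(torch.FloatTensor))
        for i in range(dbn.n_layers-1):
            prop_v = dbn.rbm_layers[i].v_to_h(prop_v)[0]
        prop = prop_v.data.mean()
    else:
        prop = 0.5

    # Random binary start state for the top hidden layer, one row per sample.
    h = torch.bernoulli((top.h_bias * 0 + prop).view(1, -1).repeat(n, 1))
    p_v, v = top.h_to_v(h)

    if not annealed:
        # Plain block Gibbs sampling in the top RBM.
        for _ in range(iteration):
            p_h, h = top.v_to_h(v)
            p_v, v = top.h_to_v(h)
    else:
        # Annealing: divide every parameter by temp (sampling at temperature
        # temp), cooling from 3 to 0.6 in 25 steps. The original values are
        # saved and restored exactly, instead of multiplying back by temp,
        # which accumulates rounding error and leaves the weights scaled if
        # sampling raises an exception.
        original = [p.data.clone() for p in top.parameters()]
        try:
            for temp in np.linspace(3, 0.6, 25):
                for p, p0 in zip(top.parameters(), original):
                    p.data.copy_(p0 / temp)
                for _ in range(iteration):
                    p_h, h = top.v_to_h(v)
                    p_v, v = top.h_to_v(h)
        finally:
            for p, p0 in zip(top.parameters(), original):
                p.data.copy_(p0)

    # Downward pass: map the samples back to data space through the lower RBMs.
    for i in range(dbn.n_layers-1):
        p_v, v = dbn.rbm_layers[-2-i].h_to_v(v)

    return v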
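generate() relies only on each layer exposing `v_to_h` / `h_to_v` (each returning a probabilities/samples pair), `h_bias` and `parameters()`; the RBM class itself is not part of this file. The NumPy sketch below assumes those are standard Bernoulli RBM conditionals, p(h=1|v) = sigmoid(vW + c) and p(v=1|h) = sigmoid(hW^T + b). `ToyRBM` and `generate_sketch` are hypothetical names, not the project's API. Temperature is applied by dividing the pre-activation by `temp`, which for these conditionals is equivalent to dividing every parameter by `temp` as the annealed branch does.

```python
import numpy as np


class ToyRBM:
    """Hypothetical Bernoulli RBM with the interface generate() expects."""

    def __init__(self, n_visible, n_hidden, rng):
        self.rng = rng
        self.W = rng.normal(0.0, 0.1, size=(n_visible, n_hidden))
        self.v_bias = np.zeros(n_visible)
        self.h_bias = np.zeros(n_hidden)

    @staticmethod
    def _sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def _bernoulli(self, p):
        # Sample binary units that are 1 with probability p.
        return (self.rng.random(p.shape) < p).astype(float)

    def v_to_h(self, v, temp=1.0):
        # p(h = 1 | v) = sigmoid((v W + c) / T); returns (probabilities, samples).
        p_h = self._sigmoid((v @ self.W + self.h_bias) / temp)
        return p_h, self._bernoulli(p_h)

    def h_to_v(self, h, temp=1.0):
        # p(v = 1 | h) = sigmoid((h W^T + b) / T); returns (probabilities, samples).
        p_v = self._sigmoid((h @ self.W.T + self.v_bias) / temp)
        return p_v, self._bernoulli(p_v)


def generate_sketch(layers, n, iteration=1, temps=None, prop=0.5, rng=None):
    """Sample n data vectors from a stack of ToyRBMs (bottom layer first)."""
    if rng is None:
        rng = np.random.default_rng()
    top = layers[-1]
    # Random binary start for the top hidden layer, as in generate().
    h = (rng.random((n, top.h_bias.size)) < prop).astype(float)
    _, v = top.h_to_v(h)
    # temps=None: plain Gibbs sampling at T = 1; otherwise one block per temperature.
    for temp in ([1.0] if temps is None else temps):
        for _ in range(iteration):
            _, h = top.v_to_h(v, temp)
            _, v = top.h_to_v(h, temp)
    # Downward pass through the lower RBMs to the data layer.
    for rbm in reversed(layers[:-1]):
        _, v = rbm.h_to_v(v)
    return v


rng = np.random.default_rng(0)
layers = [ToyRBM(8, 6, rng), ToyRBM(6, 4, rng)]  # 8 -> 6 -> 4 units
plain = generate_sketch(layers, n=5, iteration=10, rng=rng)
annealed = generate_sketch(layers, n=5, iteration=10,
                           temps=np.linspace(3, 0.6, 25), rng=rng)
print(plain.shape, annealed.shape)  # (5, 8) (5, 8)
```

With `temps=None` this mirrors the plain branch. Passing `np.linspace(3, 0.6, 25)` mirrors the annealed branch: it cools from a flattened distribution (T = 3) to a sharpened one (T = 0.6 < 1), which concentrates samples on high-probability modes rather than sampling the model distribution exactly.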