def generate(rbm, iteration = 1, p = 0.5, n = 1):
v = torch.bernoulli((rbm.v_bias *0 + p).view(1,-1).repeat(n, 1))
for _ in range(iteration):
p_h, h = rbm.v_to_h(v)
p_v, v = rbm.h_to_v(h)
return v
utils.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录