demo_letter_duvenaud.py source code

python

Project: nmp_qc   Author: priba
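import os

import torch
from torch.autograd import Variable


# `args` is assumed to be the script's module-level argparse namespace (it must provide
# `plotPath` and `cuda`); it is not defined in this excerpt.
def plot_examples(data_loader, model, epoch, plotter, ind=(0, 10, 20)):
    """Plot the input graph of each batch listed in `ind`, plus whatever the model passes to the plotting callback."""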

    # switch to evaluate mode
    model.eval()

    for i, (g, h, e, target) in enumerate(data_loader):
        if i in ind:
            subfolder_path = 'batch_' + str(i) + '_t_' + str(int(target[0][0])) + '/epoch_' + str(epoch) + '/'
            if not os.path.isdir(args.plotPath + subfolder_path):
                os.makedirs(args.plotPath + subfolder_path)

            # Count the real (non-padded) nodes: rows of h whose features are not all zero
            num_nodes = int(torch.sum(torch.sum(torch.abs(h[0, :, :]), 1) > 0))
            am = g[0, 0:num_nodes, 0:num_nodes].numpy()
            pos = h[0, 0:num_nodes, :].numpy()

            plotter.plot_graph(am, position=pos, fig_name=subfolder_path+str(i) + '_input.png')

            # Prepare input data
            if args.cuda:
                g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()
            # Wrap in Variable (required by PyTorch < 0.4; only a compatibility wrapper in later versions)
            g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target)

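            # Compute output; the model presumably invokes this callback with a class assignment
            # `cls` and a file-name suffix `id`, saving each plot next to the input plot
            model(g, h, e, lambda cls, id: plotter.plot_graph(am, position=pos, cls=cls,
                                                              fig_name=subfolder_path + id))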
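The `num_nodes` line above relies on zero padding: batches are padded to a common size, so a node counts as real only if the absolute values of its features sum to something positive, and the adjacency matrix and positions are then cropped to that leading block. A minimal pure-Python sketch of that cropping step, assuming padding rows are all zero and come after every real node; the helper names `count_real_nodes` and `crop_padded_graph` are hypothetical and not part of nmp_qc.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def count_real_nodes(features: Matrix) -> int:
    """Count rows of a zero-padded feature matrix whose absolute values sum to > 0.

    Mirrors torch.sum(torch.sum(torch.abs(h), 1) > 0) from the function above.
    """
    return sum(1 for row in features if sum(abs(x) for x in row) > 0)


def crop_padded_graph(adjacency: Matrix, features: Matrix) -> Tuple[Matrix, Matrix]:
    """Keep the leading n x n adjacency block and the first n feature rows.

    Assumes (hypothetically) that all padding nodes come after the real ones,
    matching the slicing g[0, 0:num_nodes, 0:num_nodes] above.
    """
    n = count_real_nodes(features)
    cropped_adj = [list(row[:n]) for row in adjacency[:n]]
    cropped_feat = [list(row) for row in features[:n]]
    return cropped_adj, cropped_feat


if __name__ == "__main__":
    # Hypothetical 3-node path graph padded to 5 nodes; features are 2-D positions.
    h = [[0.0, 1.0], [2.0, 3.0], [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]]
    g = [
        [0, 1, 0, 0, 0],
        [1, 0, 1, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ]
    am, pos = crop_padded_graph(g, h)
    print(len(am), pos)  # 3 [[0.0, 1.0], [2.0, 3.0], [1.0, 1.0]]
```

Note that this heuristic, like the original one-liner, only counts non-zero rows: a genuine node whose features are all zero (for example, a point placed exactly at the origin) would be treated as padding.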