def plot_examples(data_loader, model, epoch, plotter, ind = [0, 10, 20]):
# switch to evaluate mode
model.eval()
for i, (g, h, e, target) in enumerate(data_loader):
if i in ind:
subfolder_path = 'batch_' + str(i) + '_t_' + str(int(target[0][0])) + '/epoch_' + str(epoch) + '/'
if not os.path.isdir(args.plotPath + subfolder_path):
os.makedirs(args.plotPath + subfolder_path)
num_nodes = torch.sum(torch.sum(torch.abs(h[0, :, :]), 1) > 0)
am = g[0, 0:num_nodes, 0:num_nodes].numpy()
pos = h[0, 0:num_nodes, :].numpy()
plotter.plot_graph(am, position=pos, fig_name=subfolder_path+str(i) + '_input.png')
# Prepare input data
if args.cuda:
g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()
g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target)
# Compute output
model(g, h, e, lambda cls, id: plotter.plot_graph(am, position=pos, cls=cls,
fig_name=subfolder_path+ id))
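# NOTE(review): two stray web-page navigation labels ("Comment list",
# "Table of contents") were captured here with the snippet; kept as a comment
# so the file parses.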