toy.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pytorch-geometric-gan 作者: lim0606 项目源码 文件源码
def save_contour(netD, filename, cuda=False, delta=0.1, lim=25.0):
    """Evaluate ``netD`` on a square 2-D grid and save a labelled contour plot.

    A grid covering ``[-lim, lim)`` on both axes with spacing ``delta`` is
    fed through ``netD`` as a single batch. The resulting scores are drawn as
    filled contours, overlaid with labelled black contour lines, and written
    to ``filename``.

    Relies on ``np``, ``matplotlib``, ``plt``, ``torch`` and ``Variable``
    being imported at module level elsewhere in this file.

    Args:
        netD: callable (e.g. a torch module) applied to a float tensor of
            shape ``(N, 2, 1, 1)``, where channel 0 holds x and channel 1
            holds y for each of the ``N`` grid points. Its output must
            contain exactly ``N`` elements, one per grid point.
        filename: destination passed to ``Figure.savefig``.
        cuda: if True, move the input batch to the GPU before the forward pass.
        delta: grid spacing (default 0.1, the previously hard-coded value).
        lim: half-width of the square grid (default 25.0, as before).
    """
    # Build the evaluation grid; X and Y both have shape (h, w).
    x = np.arange(-lim, lim, delta)
    y = np.arange(-lim, lim, delta)
    X, Y = np.meshgrid(x, y)

    # Pack the grid into an (h*w, 2, 1, 1) batch: channel 0 = x, channel 1 = y.
    h, w = X.shape
    XY = np.concatenate((X.reshape((h * w, 1, 1, 1)),
                         Y.reshape((h * w, 1, 1, 1))), axis=1)
    inputs = Variable(torch.Tensor(XY))
    if cuda:
        inputs = inputs.cuda()

    # Forward pass, then flatten back to an (h, w) array aligned with X / Y.
    output = netD(inputs)
    Z = output.data.cpu().view(-1).numpy().reshape(h, w)

    # Scope the style overrides to this plot instead of mutating the
    # process-wide rcParams, which would otherwise leak into later figures.
    style = {
        'xtick.direction': 'out',
        'ytick.direction': 'out',
        'contour.negative_linestyle': 'solid',
    }
    with matplotlib.rc_context(style):
        fig, ax = plt.subplots()
        try:
            ax.contourf(X, Y, Z)
            lines = ax.contour(X, Y, Z, alpha=.7, colors='k')
            ax.clabel(lines, inline=1, fontsize=10, colors='k')
            ax.set_title('Simplest default with labels')
            fig.savefig(filename)
        finally:
            # Close this specific figure even if plotting or saving fails,
            # so repeated calls do not accumulate open figures.
            plt.close(fig)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号