minibatch_discrimination.py source code

python

Project: the-neural-perspective  Author: GokuMohandas
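import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm


def plot_data_and_D(sess, model, FLAGS):
    """Plot the true data density p_data next to the output of the
    untrained discriminator D, evaluated on the same grid of points."""

    # True data distribution with untrained D
    f, ax = plt.subplots(1)

    # p_data: Gaussian density on a grid spanning roughly mu +/- 3 sigma
    # (the endpoints are truncated to integers by int())
    X = np.linspace(int(FLAGS.mu-3.0*FLAGS.sigma),
                    int(FLAGS.mu+3.0*FLAGS.sigma),
                    FLAGS.num_points)
    y = norm.pdf(X, loc=FLAGS.mu, scale=FLAGS.sigma)
    ax.plot(X, y, label='p_data')

    # Untrained p_discriminator: feed the grid to D in batches of shape
    # (batch_size, 1); points past the last full batch are left at 0
    untrained_D = np.zeros((FLAGS.num_points,1))
    # Integer division (//) keeps the range bound an int under Python 3
    for i in range(FLAGS.num_points // FLAGS.batch_size):
        batch_X = np.reshape(
            X[FLAGS.batch_size*i:FLAGS.batch_size*(i+1)],
            (FLAGS.batch_size,1))
        untrained_D[FLAGS.batch_size*i:FLAGS.batch_size*(i+1)] = \
            sess.run(model.D,
                feed_dict={model.pretrained_inputs: batch_X})
    ax.plot(X, untrained_D, label='untrained_D')

    plt.legend()
    plt.show()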
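The function evaluates D over the grid in fixed-size chunks of shape (batch_size, 1), presumably because the model's input placeholder expects that batch shape. The NumPy sketch below isolates that batching pattern so it can run without TensorFlow. It assumes a hypothetical stand-in discriminator `toy_D` (a fixed sigmoid, not the model's D) in place of `sess.run(model.D, ...)`, and hypothetical values for mu, sigma, num_points and batch_size.

```python
import numpy as np
from scipy.stats import norm


def toy_D(batch_X):
    """Hypothetical stand-in for sess.run(model.D, ...): a fixed sigmoid."""
    return 1.0 / (1.0 + np.exp(-batch_X))


def evaluate_in_batches(D, X, batch_size):
    """Evaluate D on X in full batches of shape (batch_size, 1).

    Trailing points that do not fill a complete batch stay 0,
    mirroring the loop in plot_data_and_D.
    """
    out = np.zeros((len(X), 1))
    for i in range(len(X) // batch_size):
        sl = slice(batch_size * i, batch_size * (i + 1))
        out[sl] = D(np.reshape(X[sl], (batch_size, 1)))
    return out


# Hypothetical settings standing in for FLAGS
mu, sigma, num_points, batch_size = -1.0, 1.0, 10000, 12

X = np.linspace(int(mu - 3.0 * sigma), int(mu + 3.0 * sigma), num_points)
p_data = norm.pdf(X, loc=mu, scale=sigma)
D_vals = evaluate_in_batches(toy_D, X, batch_size)
```

With these values, 10000 // 12 = 833 full batches cover the first 9996 points, so the last 4 entries of `D_vals` remain 0. Choosing `num_points` as a multiple of `batch_size` avoids that gap in the plotted curve.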