def plot_data_and_D(sess, model, FLAGS):
    """Plot the true data density next to the untrained discriminator output.

    Draws two curves on a single axes and shows the figure:
    the Gaussian pdf ``N(FLAGS.mu, FLAGS.sigma)`` labelled ``p_data``, and the
    value of ``model.D`` evaluated over the same grid, labelled
    ``untrained_D``.

    Args:
        sess: Session-like object; ``sess.run(model.D, feed_dict=...)`` is
            called once per batch.
        model: Object exposing ``D`` (the fetched output) and
            ``pretrained_inputs`` (the fed input).
        FLAGS: Configuration object providing ``mu``, ``sigma``,
            ``num_points`` and ``batch_size``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # True data distribution with untrained D
    _, ax = plt.subplots(1)

    # p_data: grid spanning roughly +/- 3 sigma around the mean. The bounds
    # are truncated to integers, as in the original implementation.
    X = np.linspace(int(FLAGS.mu - 3.0 * FLAGS.sigma),
                    int(FLAGS.mu + 3.0 * FLAGS.sigma),
                    FLAGS.num_points)
    y = norm.pdf(X, loc=FLAGS.mu, scale=FLAGS.sigma)
    ax.plot(X, y, label='p_data')

    # Untrained p_discriminator
    untrained_D = np.zeros((FLAGS.num_points, 1))
    # Floor division: `range` needs an int, and `/` yields a float on
    # Python 3 (TypeError). Only full batches are evaluated; if num_points is
    # not a multiple of batch_size the trailing points keep their initial
    # value of zero.
    # NOTE(review): presumably the input expects exactly batch_size rows, so
    # a short final batch is not fed -- confirm against the model definition.
    num_batches = FLAGS.num_points // FLAGS.batch_size
    for i in range(num_batches):
        start = FLAGS.batch_size * i
        stop = start + FLAGS.batch_size
        batch_X = np.reshape(X[start:stop], (FLAGS.batch_size, 1))
        untrained_D[start:stop] = sess.run(
            model.D,
            feed_dict={model.pretrained_inputs: batch_X})
    ax.plot(X, untrained_D, label='untrained_D')

    plt.legend()
    plt.show()
minibatch_discrimination.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录