nade_test.py source code

python

Project: magenta  Author: tensorflow
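# Written against the TensorFlow 1.x graph-mode API (tf.random_normal, test_session).
import tensorflow as tf

# Assumed import path: the Nade class lives in magenta's common nade module.
from magenta.common.nade import Nade


class NadeTest(tf.test.TestCase):

  def testExternalBias(self):
    """Checks output shapes when encoder/decoder biases are passed in per example."""
    batch_size = 4
    num_hidden = 6
    num_dims = 8
    test_inputs = tf.random_normal(shape=(batch_size, num_dims))
    # External biases: one encoder bias per hidden unit and one decoder bias
    # per input dimension, for every example in the batch.
    test_b_enc = tf.random_normal(shape=(batch_size, num_hidden))
    test_b_dec = tf.random_normal(shape=(batch_size, num_dims))

    nade = Nade(num_dims, num_hidden)
    log_prob, cond_probs = nade.log_prob(test_inputs, test_b_enc, test_b_dec)
    sample, sample_prob = nade.sample(b_enc=test_b_enc, b_dec=test_b_dec)
    with self.test_session() as sess:
      sess.run([tf.global_variables_initializer()])
      # One value per example for the (log-)probabilities, and one entry per
      # input dimension for the conditionals and the samples.
      self.assertEqual(log_prob.eval().shape, (batch_size,))
      self.assertEqual(cond_probs.eval().shape, (batch_size, num_dims))
      self.assertEqual(sample.eval().shape, (batch_size, num_dims))
      self.assertEqual(sample_prob.eval().shape, (batch_size,))


if __name__ == '__main__':
  tf.test.main()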
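The test above only checks output shapes. To show what those shapes correspond to, here is a minimal pure-Python sketch of a NADE (Neural Autoregressive Distribution Estimator) log-likelihood and ancestral sampler in which the encoder and decoder biases are supplied per example, as in testExternalBias. It is an illustrative assumption, not magenta's implementation: the function names (nade_log_prob, nade_sample and the underscore helpers), the weight layout (W as num_hidden x num_dims, V as num_dims x num_hidden) and returning a log-probability for each drawn sample are choices made here for the sketch.

```python
import math
import random

# Hypothetical sketch of NADE with external (per-example) biases; this is
# not magenta's Nade class, only an illustration of the computation.


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def _cond_prob(a, v_d, b_dec_d):
    """p(x_d = 1 | x_<d) from the running hidden pre-activation `a`."""
    h = [sigmoid(a_j) for a_j in a]
    return sigmoid(b_dec_d + sum(v * h_j for v, h_j in zip(v_d, h)))


def _bernoulli_log_prob(x_d, p):
    """x log p + (1 - x) log(1 - p); for binary x this is the log-likelihood."""
    return x_d * math.log(p) + (1 - x_d) * math.log(1.0 - p)


def _advance(a, W, d, x_d):
    """Fold x_d into the hidden pre-activation so later dims condition on it."""
    for j in range(len(a)):
        a[j] += W[j][d] * x_d


def nade_log_prob(x, W, V, b_enc, b_dec):
    """Return (log_prob, cond_probs), shaped [batch] and [batch][num_dims].

    b_enc holds one num_hidden bias row per example and b_dec one num_dims
    bias row per example (the "external bias" case exercised by the test).
    """
    num_dims = len(W[0])
    log_probs, cond_probs = [], []
    for x_row, be, bd in zip(x, b_enc, b_dec):
        a = list(be)  # hidden pre-activation starts at this example's b_enc
        row_lp, row_cp = 0.0, []
        for d in range(num_dims):
            p = _cond_prob(a, V[d], bd[d])
            row_cp.append(p)
            row_lp += _bernoulli_log_prob(x_row[d], p)
            _advance(a, W, d, x_row[d])
        log_probs.append(row_lp)
        cond_probs.append(row_cp)
    return log_probs, cond_probs


def nade_sample(W, V, b_enc, b_dec, rng):
    """Ancestral sampling: draw x_1, x_2, ... in order, one row per bias row.

    Returns (samples, sample_log_probs), shaped [batch][num_dims] and [batch].
    """
    num_dims = len(W[0])
    samples, sample_log_probs = [], []
    for be, bd in zip(b_enc, b_dec):
        a = list(be)
        row, row_lp = [], 0.0
        for d in range(num_dims):
            p = _cond_prob(a, V[d], bd[d])
            x_d = 1 if rng.random() < p else 0
            row.append(x_d)
            row_lp += _bernoulli_log_prob(x_d, p)
            _advance(a, W, d, x_d)
        samples.append(row)
        sample_log_probs.append(row_lp)
    return samples, sample_log_probs


if __name__ == "__main__":
    rng = random.Random(0)
    batch_size, num_hidden, num_dims = 4, 6, 8  # same sizes as the test
    W = [[rng.gauss(0, 0.1) for _ in range(num_dims)] for _ in range(num_hidden)]
    V = [[rng.gauss(0, 0.1) for _ in range(num_hidden)] for _ in range(num_dims)]
    b_enc = [[rng.gauss(0, 1) for _ in range(num_hidden)] for _ in range(batch_size)]
    b_dec = [[rng.gauss(0, 1) for _ in range(num_dims)] for _ in range(batch_size)]
    samples, sample_lp = nade_sample(W, V, b_enc, b_dec, rng)
    log_prob, cond_probs = nade_log_prob(samples, W, V, b_enc, b_dec)
    print(len(log_prob), len(cond_probs), len(cond_probs[0]))
```

Because each row of b_enc and b_dec conditions only its own example, the bias tensors must share the batch dimension of the inputs, which is why the test builds them with shapes (batch_size, num_hidden) and (batch_size, num_dims). In this sketch, the log-probability accumulated while sampling equals nade_log_prob evaluated on the drawn sample, since both walk the same autoregressive chain.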