test_layers.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:aboleth 作者: data61 项目源码 文件源码
def test_dense_embeddings(make_categories, reps, layer):
    """Test the embedding layer."""
    x, K = make_categories
    x = np.repeat(x, reps, axis=-1)
    N = len(x)
    S = 3
    x_, X_ = _make_placeholders(x, S, tf.int32)
    output, reg = layer(output_dim=D, n_categories=K)(X_)

    tc = tf.test.TestCase()
    with tc.test_session():
        tf.global_variables_initializer().run()
        r = reg.eval()

        assert np.isscalar(r)
        assert r >= 0

        Phi = output.eval(feed_dict={x_: x})

        assert Phi.shape == (S, N, D * reps)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号