test_helpers.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:generating_sequences 作者: PFCM 项目源码 文件源码
def test_argmax_and_embed():
    """Ensure argmax_and_embed works without projection"""
    embedding = tf.get_variable('embedding', [3, 20])
    data = tf.get_variable('input', initializer=np.array([[1., 2., 1.]]))

    loop_fn = helpers.argmax_and_embed(embedding, output_projection=None)
    correct = tf.nn.embedding_lookup(embedding, [1])

    result = loop_fn(data, 0)

    # get ready to see if it's right
    sess = tf.get_default_session()
    sess.run(tf.initialize_all_variables())

    a, b = sess.run([result, correct])

    assert np.all(a == b)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号