test_helpers.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:generating_sequences 作者: PFCM 项目源码 文件源码
def test_argmax_and_embed_with_projection():
    """Ensure argmax_and_embed works with projection.

    Builds a 10x11 embedding matrix and an output projection
    ``(weights [3, 10], biases [10])``, then checks that the loop function
    returned by ``helpers.argmax_and_embed`` embeds the argmax of the
    projected input. The expected value is recomputed by hand with
    ``matmul`` + ``bias_add``.

    Requires a default TensorFlow session to be active (for example,
    installed by the surrounding test harness).
    """
    embedding = tf.get_variable('embedding', [10, 11])
    proj = (tf.get_variable('weights', [3, 10]),
            tf.get_variable('biases', [10]))
    # A single input row of width 3, matching the projection weights' input dim.
    data = tf.get_variable('input', initializer=np.array([[1., 2., 1.]],
                                                         dtype=np.float32))
    loop_fn = helpers.argmax_and_embed(embedding, output_projection=proj)

    # The projection weights are randomly initialised, so the expected symbol
    # is not known ahead of time; build the projection by hand instead.
    correct_projection = tf.nn.bias_add(tf.matmul(data, proj[0]), proj[1])

    # NOTE(review): the meaning of the second argument (0) is defined by
    # helpers.argmax_and_embed and is not visible here.
    result = loop_fn(data, 0)

    sess = tf.get_default_session()
    # Fail with a clear message rather than an AttributeError on None.
    assert sess is not None, 'this test must run inside a default tf session'
    sess.run(tf.initialize_all_variables())

    # Fetch into fresh names so the `embedding` variable handle is not
    # shadowed by its numpy value.
    embedded, embedding_values, projection = sess.run(
        [result, embedding, correct_projection])

    # One argmax per batch row; `projection` is (batch, 10) from the matmul above.
    argmax_p = np.argmax(projection, axis=1)

    # Assumes the loop function returns one embedding row per batch row,
    # i.e. shape (batch, 11) -- TODO confirm against helpers. Unlike
    # `np.all(... == ...)`, this also fails on shape mismatches or empty output.
    np.testing.assert_array_equal(embedded, embedding_values[argmax_p])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号