test_helpers.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:generating_sequences 作者: PFCM 项目源码 文件源码
def test_sample_and_embed_with_projection():
    """Ensure sample_and_embed works with projection"""
    embedding = tf.get_variable('embedding', [10, 11])
    proj = (tf.get_variable('weights', [3, 10]),
            tf.get_variable('biases', [10]))
    data = tf.get_variable('input', initializer=np.array([[1., 2., 1.]],
                                                         dtype=np.float32))

    loop_fn = helpers.sample_and_embed(embedding, 1., output_projection=proj)
    result = loop_fn(data, 0)

    # get ready to see if does indeed pick out one item
    sess = tf.get_default_session()
    sess.run(tf.initialize_all_variables())

    a, embed_mat = sess.run([result, embedding])

    found = False
    for row in embed_mat:
        if np.all(row == a):
            found = True

    assert found
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号