def test_sample_and_embed():
    """Ensure sample_and_embed works without projection.

    Builds a 3x20 embedding variable and a single 1x3 input row, runs the
    loop function returned by ``helpers.sample_and_embed`` once, and checks
    that the value it produces is exactly one of the embedding rows.
    """
    embedding = tf.get_variable('embedding', [3, 20])
    # np.array of Python floats is float64, so this variable is float64,
    # while `embedding` uses get_variable's default dtype.
    data = tf.get_variable('input', initializer=np.array([[1., 2., 1.]]))
    # NOTE(review): with output_projection=None, `data` presumably acts
    # directly as scores over the 3 embedding rows; the positional `1.` is
    # passed through to helpers.sample_and_embed -- confirm its meaning there.
    loop_fn = helpers.sample_and_embed(embedding, 1., output_projection=None)
    # presumably follows the legacy seq2seq loop_function(prev, i) convention
    # -- confirm against helpers.
    result = loop_fn(data, 0)
    # get ready to see if it does indeed pick out one item
    sess = tf.get_default_session()
    # get_default_session() returns None outside a `with sess.as_default():`
    # scope; fail with a clear message instead of an AttributeError below.
    assert sess is not None, 'test requires a default TensorFlow session'
    # NOTE(review): tf.initialize_all_variables() is deprecated in later TF 1.x
    # releases in favour of tf.global_variables_initializer(); kept as-is for
    # the TF version this suite targets.
    sess.run(tf.initialize_all_variables())
    a, embed_mat = sess.run([result, embedding])
    # Use a broadcasting comparison rather than np.array_equal so that a
    # result carrying a leading batch dimension (e.g. (1, 20)) still matches
    # a (20,) row of the matrix.
    found = any(np.all(row == a) for row in embed_mat)
    assert found, 'sampled embedding does not match any row of the embedding matrix'
评论列表
文章目录