def test_argmax_and_embed():
    """Ensure argmax_and_embed works without projection.

    Builds a loop function via ``helpers.argmax_and_embed`` with
    ``output_projection=None``, feeds it a single row whose largest entry is
    at index 1, and checks that the output equals a direct
    ``tf.nn.embedding_lookup`` of id 1 in the same embedding matrix.

    The caller (e.g. a test fixture) must have installed a default
    ``tf.Session``; the test fails with an explicit message otherwise.
    """
    embedding = tf.get_variable('embedding', [3, 20])
    # Single input row; its largest value sits at index 1, which is why the
    # expected result below is the embedding of id 1.
    data = tf.get_variable('input', initializer=np.array([[1., 2., 1.]]))
    loop_fn = helpers.argmax_and_embed(embedding, output_projection=None)
    correct = tf.nn.embedding_lookup(embedding, [1])
    # NOTE(review): the second argument (0) is presumably the decoding step
    # index expected by the loop-function signature — confirm in helpers.
    result = loop_fn(data, 0)

    # Evaluate both graphs in the session provided by the caller.
    sess = tf.get_default_session()
    # tf.get_default_session() returns None when no session is installed as
    # default; fail with a readable message instead of an AttributeError.
    assert sess is not None, 'test_argmax_and_embed needs a default tf.Session'
    # initialize_all_variables is deprecated in favour of
    # global_variables_initializer; fall back only on releases lacking it.
    if hasattr(tf, 'global_variables_initializer'):
        init_op = tf.global_variables_initializer()
    else:
        init_op = tf.initialize_all_variables()
    sess.run(init_op)
    a, b = sess.run([result, correct])
    # Unlike np.all(a == b), assert_array_equal also rejects mismatched
    # shapes, so a (20,) vs (1, 20) result cannot pass via broadcasting.
    np.testing.assert_array_equal(a, b)
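# NOTE(review): two stray web-page labels ("Comment list" and "Table of
# contents") were captured here when this file was scraped; they are not
# part of the test module.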