def test_argmax_and_embed_with_projection():
    """Ensure argmax_and_embed works with an output projection.

    Builds a 10x11 embedding matrix and a (weights [3, 10], biases [10])
    output projection, runs the loop function returned by
    ``helpers.argmax_and_embed`` on a single 1x3 input row, and checks that
    the returned value equals the embedding row selected by the argmax of
    the projected input, recomputed here by hand.

    NOTE(review): relies on a default TensorFlow session being active
    (presumably installed by a test fixture/decorator); if none is,
    ``tf.get_default_session()`` returns ``None`` and ``sess.run`` fails.
    """
    embedding = tf.get_variable('embedding', [10, 11])
    proj = (tf.get_variable('weights', [3, 10]),
            tf.get_variable('biases', [10]))
    data = tf.get_variable('input',
                           initializer=np.array([[1., 2., 1.]],
                                                dtype=np.float32))
    loop_fn = helpers.argmax_and_embed(embedding, output_projection=proj)
    # The projection variables are randomly initialised, so the expected
    # answer is not known in advance; recompute the projected logits by
    # hand (data @ weights + biases) to derive it.
    correct_projection = tf.nn.bias_add(tf.matmul(data, proj[0]), proj[1])
    result = loop_fn(data, 0)

    sess = tf.get_default_session()
    # NOTE(review): newer TF 1.x releases deprecate this in favour of
    # tf.global_variables_initializer(); kept for compatibility with the
    # TF version this file targets.
    sess.run(tf.initialize_all_variables())
    # Fetch into distinct names instead of rebinding `embedding`, so the
    # tf.Variable and its evaluated numpy value are never confused.
    result_val, embedding_val, projection_val = sess.run(
        [result, embedding, correct_projection])
    # Argmax per row (last axis) rather than over the flattened array, so
    # each index is a valid row of the embedding matrix regardless of the
    # batch size.
    argmax_p = np.argmax(projection_val, axis=-1)
    assert np.all(embedding_val[argmax_p] == result_val)
评论列表
文章目录