def _extract_argmax_and_one_hot(one_hot_size,
                                output_projection=None):
    """Build a loop function that turns the previous output into a one-hot vector.

    Args:
      one_hot_size: passed to `tf.one_hot` as its second argument (the depth,
        i.e. total size of each one-hot vector).
      output_projection: None or a pair (W, B). If provided, each fed previous
        output is first transformed with `nn_ops.xw_plus_b(prev, W, B)`
        before the argmax is taken.

    Returns:
      A function `loop_function(prev, _)` that returns the one-hot encoding of
      the argmax of `prev` along dimension 1. Its second argument is ignored.
    """
    def loop_function(prev, _):
        # Optionally project the previous output before picking a symbol.
        if output_projection is None:
            scores = prev
        else:
            weights = output_projection[0]
            biases = output_projection[1]
            scores = nn_ops.xw_plus_b(prev, weights, biases)

        # Index of the largest entry along dimension 1, then its one-hot form.
        best_symbol = math_ops.argmax(scores, 1)
        return tf.one_hot(best_symbol, one_hot_size)

    return loop_function
# Residue from the source web page (not part of the code):
# "Comment list" / "Table of contents".