def _extract_argmax_and_embed(embedding, output_projection=None,
update_embedding=True):
"""Get a loop_function that extracts the previous symbol and embeds it.
Args:
embedding: embedding tensor for symbols.
output_projection: None or a pair (W, B). If provided, each fed previous
output will first be multiplied by W and added B.
update_embedding: Boolean; if False, the gradients will not propagate
through the embeddings.
Returns:
A loop function.
"""
def loop_function(prev, _):
if output_projection is not None:
prev = nn_ops.xw_plus_b(
prev, output_projection[0], output_projection[1])
prev_symbol = math_ops.argmax(prev, 1)
# Note that gradients will not propagate through the second parameter of
# embedding_lookup.
emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
if not update_embedding:
emb_prev = array_ops.stop_gradient(emb_prev)
return emb_prev
return loop_function
评论列表
文章目录