def _extract_argmax_and_embed(embedding, output_projection=None,
update_embedding=True):
def loop_function(prev, _):
if output_projection is not None:
prev = nn_ops.xw_plus_b(
prev, output_projection[0], output_projection[1])
prev_symbol = math_ops.argmax(prev, 1)
# Note that gradients will not propagate through the second parameter of
# embedding_lookup.
emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
if not update_embedding:
emb_prev = array_ops.stop_gradient(emb_prev)
return emb_prev
return loop_function
评论列表
文章目录