def _extract_argmax_and_embed(embedding, DNN_at_output, output_projection, forward_only=False, update_embedding=True):
    """Build a decoder loop function: project the previous output, take its
    argmax symbol, and embed that symbol for the next time step.

    Args:
        embedding: Embedding matrix used to look up the argmax symbol.
        DNN_at_output: If True, project `prev` with
            `multilayer_perceptron_with_initialized_W`; otherwise with
            `linear_transformation_with_initialized_W`.
        output_projection: Pre-initialized projection weights handed to the
            chosen projection helper.
        forward_only: Forwarded unchanged to the projection helper
            (presumably an inference-mode flag — defined elsewhere).
        update_embedding: Accepted for interface compatibility but currently
            UNUSED here. NOTE(review): the canonical TensorFlow helper wraps
            `emb_prev` in `stop_gradient` when this is False — confirm whether
            that behavior was intended and dropped by mistake.

    Returns:
        A `loop_function(prev, i)` closure returning the tuple
        `(projected_prev, emb_prev)`.
    """
    def loop_function(prev, _):
        # Project the raw decoder output into vocabulary space.
        if DNN_at_output is True:
            prev = multilayer_perceptron_with_initialized_W(prev, output_projection, forward_only=forward_only)
        else:
            prev = linear_transformation_with_initialized_W(prev, output_projection, forward_only=forward_only)
        # Greedy decoding: pick the highest-scoring symbol along axis 1.
        prev_symbol = math_ops.argmax(prev, 1)
        # Note that gradients will not propagate through the second parameter of
        # embedding_lookup.
        emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
        # Unlike the stock TF helper (which returns only emb_prev), the
        # projected logits are returned as well.
        return prev, emb_prev
    return loop_function