def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
                          embedding_size, output_projection=None,
                          feed_previous=False,
                          update_embedding_for_previous=True, scope=None):
    """RNN decoder with embedding and a pure-decoding option.

    Creates an ``[num_symbols, embedding_size]`` embedding variable and
    embeds every element of ``decoder_inputs`` with it. The resulting
    sequence of embeddings is handed to ``rnn_decoder``.

    Args:
        decoder_inputs: Iterable of symbol-id tensors, one per step. Each one
            is passed to ``embedding_ops.embedding_lookup``. (Presumably 1-D
            integer tensors of shape [batch_size] -- TODO confirm.) It is
            wrapped in a lazy generator, so it is consumed only as
            ``rnn_decoder`` iterates.
        initial_state: Forwarded unchanged to ``rnn_decoder``.
        cell: Forwarded unchanged to ``rnn_decoder``.
        num_symbols: Number of rows of the embedding variable. It is also the
            expected last dimension of the projection weights and the length
            of the projection biases.
        embedding_size: Number of columns of the embedding variable.
        output_projection: None, or a pair ``(W, B)``.
            - Each entry is converted to a float32 tensor.
            - ``W`` must be shape-compatible with ``[None, num_symbols]``.
            - ``B`` must be shape-compatible with ``[num_symbols]``.
            - The pair is also forwarded to ``_extract_argmax_and_embed``.
        feed_previous: If True, a loop function is built with
            ``_extract_argmax_and_embed`` and passed to ``rnn_decoder``.
            Otherwise ``loop_function`` is None.
        update_embedding_for_previous: Forwarded to
            ``_extract_argmax_and_embed``. It has no effect when
            ``feed_previous`` is False.
        scope: Optional variable scope, or scope name, for the embedding
            variable and the decoder. Defaults to "embedding_rnn_decoder".

    Returns:
        Whatever ``rnn_decoder`` returns. This is presumably an
        ``(outputs, state)`` pair -- confirm against ``rnn_decoder``.

    Raises:
        ValueError: If a projection tensor's static shape is incompatible
            with the shapes listed above. This error comes from
            ``TensorShape.assert_is_compatible_with``.
    """
    if output_projection is not None:
        # The converted tensors are not kept. This loop only checks the
        # static shapes of W (index 0) and then B (index 1), in that order.
        expected_shapes = ([None, num_symbols], [num_symbols])
        for position, expected_shape in enumerate(expected_shapes):
            as_tensor = ops.convert_to_tensor(output_projection[position],
                                              dtype=dtypes.float32)
            as_tensor.get_shape().assert_is_compatible_with(expected_shape)

    target_scope = scope or "embedding_rnn_decoder"
    with variable_scope.variable_scope(target_scope):
        embedding = variable_scope.get_variable(
            "embedding", [num_symbols, embedding_size])

        if feed_previous:
            # Decoding mode: rnn_decoder receives a loop function built from
            # the embedding and the optional projection.
            loop_function = _extract_argmax_and_embed(
                embedding, output_projection, update_embedding_for_previous)
        else:
            loop_function = None

        # A generator rather than a list, so the lookup ops are created only
        # as rnn_decoder iterates. Assumes rnn_decoder makes a single pass.
        embedded_inputs = (
            embedding_ops.embedding_lookup(embedding, symbol_ids)
            for symbol_ids in decoder_inputs)
        return rnn_decoder(embedded_inputs, initial_state, cell,
                           loop_function=loop_function)
评论列表
文章目录