from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
import tensorflow as tf


def embedding_attention_decoder(initial_state,
                                attention_states,
                                cell,
                                num_symbols,
                                time_steps,
                                batch_size,
                                embedding_size,
                                output_size=None,
                                output_projection=None,
                                feed_previous=False,
                                update_embedding_for_previous=True,
                                dtype=None,
                                scope=None):
  # Default to the cell's output size when no explicit output size is given.
  if output_size is None:
    output_size = cell.output_size
  # If an output projection is supplied, check that its bias vector matches
  # the vocabulary size.
  if output_projection is not None:
    proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
    proj_biases.get_shape().assert_is_compatible_with([num_symbols])

  with variable_scope.variable_scope(
      scope or "embedding_attention_decoder", dtype=dtype) as scope:
    # Decoder embedding matrix: one embedding_size vector per output symbol.
    embedding = variable_scope.get_variable("embedding",
                                            [num_symbols, embedding_size])
    # When feed_previous is set, feed the argmax of the previous output
    # (projected and re-embedded) back in as the next decoder input.
    loop_function = tf.nn.seq2seq._extract_argmax_and_embed(
        embedding, output_projection,
        update_embedding_for_previous) if feed_previous else None
    # Delegate to the (modified) attention_decoder, which takes num_symbols,
    # time_steps and batch_size instead of precomputed decoder inputs.
    return attention_decoder(
        initial_state,
        attention_states,
        cell,
        num_symbols,
        time_steps,
        batch_size,
        output_size=output_size,
        loop_function=loop_function)
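
For orientation, below is a minimal sketch of how this decoder might be called. All concrete values (vocabulary size, tensor shapes, the GRUCell, and the placeholder standing in for the encoder outputs) are illustrative assumptions rather than values from the article, and the call assumes the modified attention_decoder from the earlier sections is in scope and, like the stock version, returns (outputs, state):

import tensorflow as tf

# Assumed sizes, for illustration only.
batch_size, time_steps = 32, 10
num_symbols, embedding_size = 10000, 128
attn_length, attn_size = 20, 256

cell = tf.nn.rnn_cell.GRUCell(attn_size)
initial_state = tf.zeros([batch_size, cell.state_size])
# Encoder top-layer outputs reshaped to [batch, attn_length, attn_size].
attention_states = tf.placeholder(tf.float32,
                                  [batch_size, attn_length, attn_size])

outputs, state = embedding_attention_decoder(
    initial_state, attention_states, cell,
    num_symbols, time_steps, batch_size, embedding_size,
    feed_previous=True)  # greedy decoding via the argmax loop function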