def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
num_encoder_symbols, num_decoder_symbols,
embedding_size, output_projection=None,
feed_previous=False, dtype=dtypes.float32,
scope=None):
"""Embedding RNN sequence-to-sequence model.
"""
with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq"):
# Encoder.
encoder_cell = rnn_cell.EmbeddingWrapper(
cell, embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
_, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
# Decoder.
if output_projection is None:
cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
if isinstance(feed_previous, bool):
return embedding_rnn_decoder(
decoder_inputs, encoder_state, cell, num_decoder_symbols,
embedding_size, output_projection=output_projection,
feed_previous=feed_previous)
# If feed_previous is a Tensor, we construct 2 graphs and use cond.
def decoder(feed_previous_bool):
reuse = None if feed_previous_bool else True
with variable_scope.variable_scope(variable_scope.get_variable_scope(),
reuse=reuse):
outputs, state = embedding_rnn_decoder(
decoder_inputs, encoder_state, cell, num_decoder_symbols,
embedding_size, output_projection=output_projection,
feed_previous=feed_previous_bool,
update_embedding_for_previous=False)
return outputs + [state]
outputs_and_state = control_flow_ops.cond(feed_previous,
lambda: decoder(True),
lambda: decoder(False))
return outputs_and_state[:-1], outputs_and_state[-1]
评论列表
文章目录