def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
                          num_encoder_symbols, num_decoder_symbols,
                          embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          dtype=None,
                          scope=None,
                          beam_search=True,
                          beam_size=10):
    """Build an embedding RNN sequence-to-sequence model.

    The encoder wraps `cell` in an `EmbeddingWrapper` and runs it over
    `encoder_inputs`; only its final state is kept. That state is then
    handed to `embedding_rnn_decoder` together with `decoder_inputs`.

    Args:
        encoder_inputs: Inputs fed to the embedding-wrapped encoder cell
            (presumably a list of 1-D int32 id tensors, as in the upstream
            TensorFlow function -- confirm against callers).
        decoder_inputs: Sequence of decoder inputs; its length is used to
            separate outputs from state in the tensor `feed_previous` case.
        cell: RNN cell used for the encoder and, possibly wrapped in an
            output projection, for the decoder.
        num_encoder_symbols: Number of embedding classes for the encoder.
        num_decoder_symbols: Number of decoder symbols; also the output
            size of the projection wrapper when `output_projection` is None.
        embedding_size: Embedding size used by encoder and decoder.
        output_projection: Forwarded to `embedding_rnn_decoder`. When None,
            `cell` is wrapped in `OutputProjectionWrapper` instead.
        feed_previous: Either a Python bool or a value usable as the
            predicate of `control_flow_ops.cond` (a boolean tensor). It is
            forwarded to the decoder as `feed_previous`.
        dtype: If given, set as the dtype of the variable scope; otherwise
            the scope's own dtype is used for the encoder.
        scope: Variable scope (or name); defaults to
            "embedding_rnn_seq2seq".
        beam_search: Forwarded to `embedding_rnn_decoder`.
        beam_size: Forwarded to `embedding_rnn_decoder`.

    Returns:
        If `feed_previous` is a Python bool: whatever
        `embedding_rnn_decoder` returns.
        Otherwise: a tuple `(outputs, state)`, where `outputs` holds the
        first `len(decoder_inputs)` results of the conditional and `state`
        is rebuilt to match the structure of the encoder state.
    """
    with variable_scope.variable_scope(
            scope or "embedding_rnn_seq2seq") as scope:
        if dtype is not None:
            scope.set_dtype(dtype)
        else:
            dtype = scope.dtype

        # Encoder: embed the input symbols and keep only the final state.
        encoder_cell = rnn_cell.EmbeddingWrapper(
            cell, embedding_classes=num_encoder_symbols,
            embedding_size=embedding_size)
        _, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)

        # Decoder: without an explicit projection, project the cell output
        # to num_decoder_symbols directly.
        if output_projection is None:
            cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)

        if isinstance(feed_previous, bool):
            # NOTE(review): this branch passes `scope=scope` to the decoder
            # while the tensor branch below does not, so the two branches
            # may create decoder variables under different scope names --
            # confirm this is intended before sharing checkpoints.
            return embedding_rnn_decoder(
                decoder_inputs, encoder_state, cell, num_decoder_symbols,
                embedding_size,
                output_projection=output_projection,
                feed_previous=feed_previous,
                scope=scope,
                beam_search=beam_search,
                beam_size=beam_size)

        # If feed_previous is a Tensor, we construct 2 graphs and use cond.
        def decoder(feed_previous_bool):
            # The True branch is built first and creates the variables; the
            # False branch reuses them.
            reuse = None if feed_previous_bool else True
            with variable_scope.variable_scope(
                    variable_scope.get_variable_scope(), reuse=reuse):
                # NOTE(review): assumes embedding_rnn_decoder returns exactly
                # (outputs, state) here, also when beam_search is enabled --
                # confirm against its definition.
                outputs, state = embedding_rnn_decoder(
                    decoder_inputs, encoder_state, cell, num_decoder_symbols,
                    embedding_size,
                    output_projection=output_projection,
                    feed_previous=feed_previous_bool,
                    update_embedding_for_previous=False,
                    beam_search=beam_search,
                    beam_size=beam_size)
                # Both cond branches return one flat list, so a nested state
                # is flattened and appended after the outputs.
                state_list = [state]
                if nest.is_sequence(state):
                    state_list = nest.flatten(state)
                return outputs + state_list

        outputs_and_state = control_flow_ops.cond(feed_previous,
                                                  lambda: decoder(True),
                                                  lambda: decoder(False))
        # Outputs length same as decoder inputs.
        outputs_len = len(decoder_inputs)
        state_list = outputs_and_state[outputs_len:]
        state = state_list[0]
        if nest.is_sequence(encoder_state):
            # Restore the nested state structure flattened inside decoder().
            state = nest.pack_sequence_as(structure=encoder_state,
                                          flat_sequence=state_list)
        return outputs_and_state[:outputs_len], state
grl_seq2seq.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录