def rnn_decoder(decoder_inputs, initial_state, cell, scope=None):
    """Build the training and sampling sub-graphs of an RNN decoder.

    The training sub-graph runs `cell` once per entry of `decoder_inputs`,
    feeding each entry together with the most recent training state. The
    sampling sub-graph runs for the same number of steps, but from the second
    step on it feeds the previous sampled output back in as the next input.

    Args:
        decoder_inputs: Inputs for the decoder, list of tensors. Their values
            are consumed only by the training sub-graph; the sampling
            sub-graph uses only their count.
        initial_state: Initial state for the decoder. It is the first entry
            of both returned state lists.
        cell: RNN cell to use for the decoder; called as
            `cell(input, state)` and expected to return
            `(output, new_state)`.
        scope: Variable scope to build under. When falsy, "dnn_decoder" is
            used.

    Returns:
        A 4-tuple `(outputs, states, sampling_outputs, sampling_states)` of
        lists. Each state list starts with `initial_state` and so has one
        more entry than the matching output list.
    """
    with vs.variable_scope(scope or "dnn_decoder"):
        outputs = []
        states = [initial_state]
        sampling_outputs = []
        sampling_states = [initial_state]

        with ops.name_scope("training",
                            values=[decoder_inputs, initial_state]):
            for step, step_input in enumerate(decoder_inputs):
                if step > 0:
                    # Turn on variable reuse after the first step so that
                    # subsequent cell calls reuse the variables created on
                    # the first step.
                    vs.get_variable_scope().reuse_variables()
                step_output, step_state = cell(step_input, states[-1])
                outputs.append(step_output)
                states.append(step_state)

        with ops.name_scope("sampling", values=[initial_state]):
            for step, _ in enumerate(decoder_inputs):
                if step:
                    # Feed the previous sampled output back in as the input.
                    prev_output = sampling_outputs[-1]
                    prev_state = sampling_states[-1]
                    next_output, next_state = cell(prev_output, prev_state)
                else:
                    # The first sampled step is seeded from the training
                    # sub-graph instead of calling the cell again.
                    # NOTE(review): states[0] is initial_state, so the second
                    # sampling step pairs outputs[0] with the initial state
                    # rather than with states[1] -- confirm this is intended.
                    next_output, next_state = outputs[step], states[step]
                sampling_outputs.append(next_output)
                sampling_states.append(next_state)

    return outputs, states, sampling_outputs, sampling_states
评论列表
文章目录