def generate_sequence_output(encoder_outputs,
                             encoder_state,
                             num_decoder_symbols,
                             sequence_length,
                             num_heads=1,
                             dtype=dtypes.float32,
                             use_attention=True,
                             loop_function=None,
                             scope=None,
                             DNN_at_output=False,
                             forward_only=False):
    """Build one output logit per encoder time step (no attention).

    Every element of `encoder_outputs` is projected to `num_decoder_symbols`
    values, using `linear_transformation` by default or
    `multilayer_perceptron` when `DNN_at_output` is true. Variables are
    created on the first step and reused on every later step.

    Args:
      encoder_outputs: Non-empty, indexable sequence of per-step tensors.
        The first element's static shape is used to derive the batch size
        and, when it is not rank 1, the per-step output size.
      encoder_state: Not referenced by this function.
      num_decoder_symbols: Size of the last dimension of each logit.
      sequence_length: Optional per-example lengths. When not None it is
        cast to int32 and every step is routed through `_step` together
        with the min/max length and an all-zero logit (presumably so that
        steps past an example's length yield the zero logit -- confirm
        against `_step`).
      num_heads: Not referenced by this function.
      dtype: Not referenced by this function.
      use_attention: Not referenced by this function.
      loop_function: Not referenced by this function.
      scope: Variable scope to build under; defaults to "non-attention_RNN".
      DNN_at_output: If true, use `multilayer_perceptron` (and its
        regularizers) instead of `linear_transformation`.
      forward_only: Forwarded to `multilayer_perceptron` only.

    Returns:
      A tuple `(logits, sequence_attention_weights, regularizers)` where
      `logits` is a list with one entry per element of `encoder_outputs`,
      `sequence_attention_weights` is always an empty list, and
      `regularizers` is whatever the matching `get_*_regularizers()` helper
      returns.
    """
    with variable_scope.variable_scope(scope or "non-attention_RNN"):
        attention_encoder_outputs = []
        # Never populated here; returned empty so the result has three
        # elements.
        sequence_attention_weights = []

        first_shape = encoder_outputs[0].get_shape()
        if first_shape.ndims != 1:
            (fixed_batch_size, output_size) = first_shape.with_rank(2)
        else:
            # NOTE(review): `output_size` is not bound on this path, so
            # building a logit below would fail for rank-1 inputs. The
            # intended size cannot be determined from this file -- confirm
            # whether rank-1 inputs are ever expected.
            fixed_batch_size = first_shape.with_rank_at_least(1)[0]

        # Prefer the static batch size; fall back to the dynamic one.
        if fixed_batch_size.value:
            batch_size = fixed_batch_size.value
        else:
            batch_size = array_ops.shape(encoder_outputs[0])[0]

        if sequence_length is not None:
            sequence_length = math_ops.to_int32(sequence_length)
            # Prepare the variables handed to `_step` on every time step.
            zero_logit = array_ops.zeros(
                array_ops.pack([batch_size, num_decoder_symbols]),
                encoder_outputs[0].dtype)
            zero_logit.set_shape(
                tensor_shape.TensorShape(
                    [fixed_batch_size.value, num_decoder_symbols]))
            min_sequence_length = math_ops.reduce_min(sequence_length)
            max_sequence_length = math_ops.reduce_max(sequence_length)

        for time, input_ in enumerate(encoder_outputs):
            if time > 0:
                variable_scope.get_variable_scope().reuse_variables()

            # Bind this step's input as a default argument so the callable
            # does not depend on the loop variable's value at call time.
            if not DNN_at_output:
                generate_logit = lambda step_input=input_: (
                    linear_transformation(
                        step_input, output_size, num_decoder_symbols))
            else:
                # The constant 200 is passed as the helper's third argument
                # (presumably a hidden-layer size -- confirm against
                # `multilayer_perceptron`).
                generate_logit = lambda step_input=input_: (
                    multilayer_perceptron(
                        step_input, output_size, 200, num_decoder_symbols,
                        forward_only=forward_only))

            if sequence_length is not None:
                logit = _step(
                    time, sequence_length, min_sequence_length,
                    max_sequence_length, zero_logit, generate_logit)
            else:
                # Call the builder: the result list must hold the logit,
                # not the function that would create it.
                logit = generate_logit()
            attention_encoder_outputs.append(logit)

        if DNN_at_output:
            regularizers = get_multilayer_perceptron_regularizers()
        else:
            regularizers = get_linear_transformation_regularizers()
    return attention_encoder_outputs, sequence_attention_weights, regularizers
# Comment list
# Article table of contents