def seq_labeling_decoder_linear(decoder_inputs, num_decoder_symbols,
                                scope=None, sequence_length=None, dtype=tf.float32):
    """Linear (non-attention) decoder for sequence labeling.

    Projects each decoder input through a single shared linear layer to
    produce per-timestep logits over the label vocabulary. For steps past a
    sequence's valid length, zero logits are emitted instead (via `_step`).

    Args:
      decoder_inputs: list of 2-D Tensors `[batch_size, input_size]`, one per
        timestep. (1-D inputs are tolerated only for batch-size inference.)
      num_decoder_symbols: int, size of the output label vocabulary.
      scope: optional variable scope name; defaults to "non-attention_RNN".
      sequence_length: optional 1-D int Tensor `[batch_size]` of valid lengths.
        When given, logits beyond each sequence's length are zero.
      dtype: unused here; kept for interface compatibility with callers.

    Returns:
      A list of 2-D Tensors `[batch_size, num_decoder_symbols]`, one per
      timestep.
    """
    with tf.variable_scope(scope or "non-attention_RNN"):
        decoder_outputs = []

        # Infer a static batch size when the input shape provides one;
        # otherwise fall back to the dynamic shape.
        if decoder_inputs[0].get_shape().ndims != 1:
            (fixed_batch_size, output_size) = decoder_inputs[0].get_shape().with_rank(2)
        else:
            fixed_batch_size = decoder_inputs[0].get_shape().with_rank_at_least(1)[0]
        if fixed_batch_size.value:
            batch_size = fixed_batch_size.value
        else:
            batch_size = tf.shape(decoder_inputs[0])[0]

        if sequence_length is not None:
            sequence_length = math_ops.to_int32(sequence_length)
            # Zero logits are copied over once a sequence runs out of steps.
            zero_logit = tf.zeros(
                tf.stack([batch_size, num_decoder_symbols]), decoder_inputs[0].dtype)
            zero_logit.set_shape(
                tensor_shape.TensorShape([fixed_batch_size.value, num_decoder_symbols]))
            min_sequence_length = math_ops.reduce_min(sequence_length)
            max_sequence_length = math_ops.reduce_max(sequence_length)

        for time, input_ in enumerate(decoder_inputs):
            # Reuse the same projection variables across all timesteps.
            if time > 0:
                tf.get_variable_scope().reuse_variables()
            # pylint: disable=cell-var-from-loop
            # Deferred so _step can skip the projection for finished sequences.
            generate_logit = lambda: _linear(input_, num_decoder_symbols, True)
            # pylint: enable=cell-var-from-loop
            if sequence_length is not None:
                logit = _step(
                    time, sequence_length, min_sequence_length,
                    max_sequence_length, zero_logit, generate_logit)
            else:
                # BUG FIX: the lambda must be *called* — the original appended
                # the function object itself, not a logits Tensor.
                logit = generate_logit()
            decoder_outputs.append(logit)
        return decoder_outputs
# 评论列表 ("comment list") / 文章目录 ("article table of contents"):
# non-code residue from the web page this snippet was copied from —
# commented out so it cannot raise a NameError at import time.