seq_labeling.py source code


Project: joint-slu-lm  Author: HadoopIt
# Imports required by this snippet (TF 0.x-era internal modules):
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope


def generate_sequence_output(encoder_outputs,
                             encoder_state,
                             num_decoder_symbols,
                             sequence_length,
                             num_heads=1,
                             dtype=dtypes.float32,
                             use_attention=True,
                             loop_function=None,
                             scope=None,
                             DNN_at_output=False,
                             forward_only=False):
  """Projects each encoder output to per-token label logits.

  Timesteps at or beyond a sequence's true length receive zero logits.
  Returns (logits, attention_weights, regularizers); attention weights
  are left empty in this non-attention path. Helpers such as
  linear_transformation, multilayer_perceptron and _step are defined
  elsewhere in seq_labeling.py.
  """
  with variable_scope.variable_scope(scope or "non-attention_RNN"):
    attention_encoder_outputs = list()
    sequence_attention_weights = list()

    # copy over logits once out of sequence_length
    if encoder_outputs[0].get_shape().ndims != 1:
      (fixed_batch_size, output_size) = encoder_outputs[0].get_shape().with_rank(2)
    else:
      fixed_batch_size = encoder_outputs[0].get_shape().with_rank_at_least(1)[0]

    if fixed_batch_size.value: 
      batch_size = fixed_batch_size.value
    else:
      batch_size = array_ops.shape(encoder_outputs[0])[0]
    if sequence_length is not None:
      sequence_length = math_ops.to_int32(sequence_length)
    if sequence_length is not None:  # Prepare variables
      zero_logit = array_ops.zeros(
          array_ops.pack([batch_size, num_decoder_symbols]), encoder_outputs[0].dtype)
      zero_logit.set_shape(
          tensor_shape.TensorShape([fixed_batch_size.value, num_decoder_symbols]))
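      # Note: array_ops.pack was renamed to array_ops.stack in TF 1.0.
      # The min/max sequence lengths let _step skip the per-example select
      # when every sequence in the batch is still active (or already done).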
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)

    for time, input_ in enumerate(encoder_outputs):
      if time > 0: variable_scope.get_variable_scope().reuse_variables()

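      # Build the per-timestep projection lazily: a plain linear layer, or a
      # small MLP when DNN_at_output is set. Wrapping it in a lambda defers
      # variable work until _step decides this timestep is inside the sequence.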
      # pylint: disable=cell-var-from-loop
      if not DNN_at_output:
        generate_logit = lambda: linear_transformation(encoder_outputs[time], output_size, num_decoder_symbols)
      else:
        generate_logit = lambda: multilayer_perceptron(encoder_outputs[time], output_size, 200, num_decoder_symbols, forward_only=forward_only)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        logit = _step(
            time, sequence_length, min_sequence_length, max_sequence_length, zero_logit, generate_logit)
      else:
        logit = generate_logit()
      attention_encoder_outputs.append(logit)   
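    # Fetch the regularization terms for whichever projection layer was
    # built, so the caller can add them to the training loss.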
    if DNN_at_output:  
      regularizers = get_multilayer_perceptron_regularizers()
    else:
      regularizers = get_linear_transformation_regularizers()
  return attention_encoder_outputs, sequence_attention_weights, regularizers
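
For context, the sequence_length branch above implements length masking: timesteps past a sequence's true length contribute zero logits instead of running the projection. Below is a minimal, self-contained NumPy sketch of that idea; mask_past_length is a hypothetical helper for illustration, not part of the repository, which does this inside _step with TF ops.

import numpy as np

def mask_past_length(logits, sequence_length):
    # logits: (max_time, batch_size, num_symbols);
    # sequence_length: (batch_size,) true length of each sequence.
    max_time = logits.shape[0]
    time_index = np.arange(max_time)[:, None]                   # (max_time, 1)
    valid = time_index < np.asarray(sequence_length)[None, :]   # (max_time, batch)
    # Broadcast the 0/1 validity mask over the symbol dimension.
    return logits * valid[:, :, None]

# Batch of two sequences, lengths 1 and 3, padded to max_time = 3.
logits = np.ones((3, 2, 4))
masked = mask_past_length(logits, [1, 3])
print(masked[:, 0, 0])  # [1. 0. 0.] -- steps past length 1 are zeroed
print(masked[:, 1, 0])  # [1. 1. 1.] -- full-length sequence is untouched

The TF version avoids materializing a full mask each step: _step (presumably modeled on TensorFlow's internal _rnn_step) compares the scalar time against the min/max sequence lengths and only falls back to a per-example select when the batch contains both finished and unfinished sequences.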