decoders.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:opinatt 作者: epochx 项目源码 文件源码
def seq_labeling_decoder_linear(decoder_inputs, num_decoder_symbols,
                                scope=None, sequence_length=None, dtype=tf.float32):
  """Map every decoder input step to logits with one shared `_linear` layer.

  There is no recurrence: each time step is projected independently through
  `_linear(input, num_decoder_symbols, True)`. The variables are created on
  the first step and reused on every later step via `reuse_variables()`.

  Args:
    decoder_inputs: list of input tensors, one per time step. The first
      element's static shape must be rank 2 (rank 1 is also accepted); the
      leading dimension is treated as the batch size.
    num_decoder_symbols: width of each produced logit tensor.
    scope: optional variable scope name; defaults to "non-attention_RNN".
    sequence_length: optional per-example lengths (cast to int32). When given,
      each step goes through `_step` together with an all-zero logit tensor
      and the min/max lengths. Presumably this zeroes logits past each
      example's length; confirm against `_step`.
    dtype: accepted for interface compatibility; not used here.

  Returns:
    A list with one entry per element of `decoder_inputs`: the `_step` result
    when `sequence_length` is given, otherwise the `_linear` output.
  """
  with tf.variable_scope(scope or "non-attention_RNN"):
    decoder_outputs = []

    # Static batch dimension of the first input; `with_rank*` also validates
    # the rank and raises if it is incompatible.
    first_shape = decoder_inputs[0].get_shape()
    if first_shape.ndims != 1:
      fixed_batch_size, _ = first_shape.with_rank(2)
    else:
      fixed_batch_size = first_shape.with_rank_at_least(1)[0]

    if sequence_length is not None:
      sequence_length = math_ops.to_int32(sequence_length)
      # Prefer the static batch size; fall back to the dynamic one.
      if fixed_batch_size.value:
        batch_size = fixed_batch_size.value
      else:
        batch_size = tf.shape(decoder_inputs[0])[0]
      # Zero logits handed to `_step` for positions outside sequence_length.
      zero_logit = tf.zeros(
        tf.stack([batch_size, num_decoder_symbols]), decoder_inputs[0].dtype)
      zero_logit.set_shape(
        tensor_shape.TensorShape([fixed_batch_size.value, num_decoder_symbols]))
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)

    for time, input_ in enumerate(decoder_inputs):
      if time > 0:
        # Reuse the projection variables created on the first step.
        tf.get_variable_scope().reuse_variables()

      # Bind this step's input as a default argument so the callable does not
      # depend on the loop variable's value when it is eventually invoked.
      def generate_logit(step_input=input_):
        return _linear(step_input, num_decoder_symbols, True)

      if sequence_length is not None:
        logit = _step(time, sequence_length, min_sequence_length,
                      max_sequence_length, zero_logit, generate_logit)
      else:
        # Without lengths there is nothing to mask: compute the logits now.
        # (Storing `generate_logit` itself would return a callable, not a tensor.)
        logit = generate_logit()
      decoder_outputs.append(logit)

  return decoder_outputs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号