decoder.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:RFR-solution 作者: baoblackcoal 项目源码 文件源码
def embedding_attention_decoder(initial_state,
                                attention_states,
                                cell,
                                num_symbols,
                                time_steps,
                                batch_size,
                                embedding_size,
                                output_size=None,
                                output_projection=None,
                                feed_previous=False,
                                update_embedding_for_previous=True,
                                dtype=None,
                                scope=None):
  """Create a symbol embedding and run `attention_decoder` with it.

  A variable named "embedding" with shape [num_symbols, embedding_size] is
  fetched/created inside the variable scope `scope` (or
  "embedding_attention_decoder" when `scope` is falsy). When `feed_previous`
  is true, a loop function built from that embedding is passed to
  `attention_decoder`; otherwise `loop_function` is None.

  Args:
    initial_state: Initial decoder state, forwarded to `attention_decoder`.
    attention_states: Attention inputs, forwarded to `attention_decoder`.
    cell: RNN cell; its `output_size` is the default for `output_size`.
    num_symbols: First dimension of the embedding variable; the bias in
      `output_projection` must be shape-compatible with [num_symbols].
    time_steps: Forwarded positionally to `attention_decoder`.
    batch_size: Forwarded positionally to `attention_decoder`.
    embedding_size: Second dimension of the embedding variable.
    output_size: Decoder output size; falls back to `cell.output_size`.
    output_projection: Optional indexable pair; element [1] (presumably the
      projection bias — confirm against callers) is shape-checked. The whole
      value is handed to the loop-function builder.
    feed_previous: Whether to build and pass a loop function.
    update_embedding_for_previous: Passed through to the loop-function
      builder.
    dtype: Used when converting the projection bias and as the variable
      scope's dtype. NOTE(review): it is not forwarded to
      `attention_decoder`.
    scope: Optional variable scope or scope name.

  Returns:
    Whatever `attention_decoder` returns.
  """
  output_size = cell.output_size if output_size is None else output_size

  # Shape sanity check only; the converted bias tensor is not kept.
  if output_projection is not None:
    ops.convert_to_tensor(
        output_projection[1],
        dtype=dtype).get_shape().assert_is_compatible_with([num_symbols])

  with variable_scope.variable_scope(
      scope or "embedding_attention_decoder", dtype=dtype):
    embedding = variable_scope.get_variable(
        "embedding", [num_symbols, embedding_size])

    loop_function = None
    if feed_previous:
      # NOTE(review): private helper from the TF 0.x `tf.nn.seq2seq` module
      # (later moved to `tf.contrib.legacy_seq2seq`); pinned TF version
      # should be confirmed before upgrading.
      loop_function = tf.nn.seq2seq._extract_argmax_and_embed(
          embedding, output_projection, update_embedding_for_previous)

    return attention_decoder(initial_state,
                             attention_states,
                             cell,
                             num_symbols,
                             time_steps,
                             batch_size,
                             output_size=output_size,
                             loop_function=loop_function)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号