many2one_seq2seq.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:seq2seq_parser 作者: trangham283 项目源码 文件源码
def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
                          embedding_size, output_projection=None,
                          feed_previous=False,
                          update_embedding_for_previous=True, scope=None):
  """Embed the decoder inputs and run them through `rnn_decoder`.

  An "embedding" variable of shape [num_symbols, embedding_size] is created
  (or fetched) inside the variable scope, every element of `decoder_inputs`
  is looked up in it, and the embedded sequence is handed to `rnn_decoder`.

  Args:
    decoder_inputs: Iterable whose elements are each passed as the ids
      argument of `embedding_ops.embedding_lookup`.
    initial_state: Forwarded unchanged to `rnn_decoder`.
    cell: Forwarded unchanged to `rnn_decoder`.
    num_symbols: Number of rows of the embedding variable. Also the size the
      projection weights' second dimension and the projection biases' only
      dimension are checked against.
    embedding_size: Number of columns of the embedding variable.
    output_projection: None, or an indexable pair (weights, biases). When
      given, both entries are converted to float32 tensors and checked for
      shape compatibility with [None, num_symbols] and [num_symbols]
      respectively; the pair itself is passed on to
      `_extract_argmax_and_embed` when `feed_previous` is set.
    feed_previous: When truthy, a loop function is built with
      `_extract_argmax_and_embed` and given to `rnn_decoder`; otherwise the
      loop function is None.
    update_embedding_for_previous: Forwarded to `_extract_argmax_and_embed`;
      only consulted when `feed_previous` is truthy.
    scope: Variable scope to use; falls back to "embedding_rnn_decoder" when
      falsy.

  Returns:
    The result of `rnn_decoder` called on the embedded inputs.
  """
  if output_projection is not None:
    # These tensors exist only to validate the projection shapes up front;
    # the decoder itself receives the original `output_projection` pair.
    weights = ops.convert_to_tensor(output_projection[0],
                                    dtype=dtypes.float32)
    weights.get_shape().assert_is_compatible_with([None, num_symbols])
    biases = ops.convert_to_tensor(output_projection[1],
                                   dtype=dtypes.float32)
    biases.get_shape().assert_is_compatible_with([num_symbols])

  with variable_scope.variable_scope(scope or "embedding_rnn_decoder"):
    embedding = variable_scope.get_variable(
        "embedding", [num_symbols, embedding_size])

    loop_function = None
    if feed_previous:
      loop_function = _extract_argmax_and_embed(
          embedding, output_projection, update_embedding_for_previous)

    # Kept as a lazy generator on purpose: each lookup is only performed at
    # the moment `rnn_decoder` pulls the corresponding input.
    embedded_inputs = (
        embedding_ops.embedding_lookup(embedding, token_ids)
        for token_ids in decoder_inputs)
    return rnn_decoder(embedded_inputs, initial_state, cell,
                       loop_function=loop_function)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号