seq2seq.py source code

Project: Variational-Recurrent-Autoencoder-Tensorflow  Author: Chung-I
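import tensorflow as tf
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import variable_scope

# Written against the pre-1.0 TensorFlow API (rnn.rnn, rnn.bidirectional_rnn,
# and the tf.concat(axis, values) argument order).


def embedding_encoder(encoder_inputs,
                      cell,
                      embedding,
                      num_symbols,
                      embedding_size,
                      bidirectional=False,
                      dtype=None,
                      weight_initializer=None,
                      scope=None):
  """Embeds a time-major list of token-id tensors and encodes it with an RNN.

  Returns the final encoder state. In the bidirectional case the forward and
  backward final states are concatenated along the feature axis.
  """
  with variable_scope.variable_scope(
      scope or "embedding_encoder", dtype=dtype) as scope:
    dtype = scope.dtype
    # Create the embedding matrix only if the caller did not supply one.
    # `is None` avoids using a Tensor as a Python bool, which TensorFlow rejects.
    if embedding is None:
      # weight_initializer is treated as a factory returning an initializer;
      # without it, get_variable falls back to its default initializer.
      initializer = weight_initializer() if weight_initializer else None
      embedding = variable_scope.get_variable(
          "embedding", [num_symbols, embedding_size], initializer=initializer)
    # One [batch, embedding_size] tensor per time step.
    emb_inp = [embedding_ops.embedding_lookup(embedding, i) for i in encoder_inputs]
    if bidirectional:
      # Per-step outputs are discarded; only the final states are kept.
      _, output_state_fw, output_state_bw = rnn.bidirectional_rnn(
          cell, cell, emb_inp, dtype=dtype)
      # Assumes each state is a single 2-D tensor (not an LSTMStateTuple).
      encoder_state = tf.concat(1, [output_state_fw, output_state_bw])
    else:
      _, encoder_state = rnn.rnn(cell, emb_inp, dtype=dtype)

    return encoder_state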
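The embedding step above turns each time step's token ids into vectors. For a single embedding matrix, `embedding_lookup` selects the rows indexed by the ids, so the step can be sketched with plain NumPy row indexing. This is an illustrative sketch, not the project's code: the helper name `embed_time_major` and the toy sizes are hypothetical, and it assumes `encoder_inputs` is a time-major Python list of integer arrays of shape `[batch]`, as the list comprehension in the original implies.

```python
import numpy as np


def embed_time_major(encoder_inputs, embedding):
    """Hypothetical helper: gather one embedding row per token id, per time step.

    encoder_inputs: list of length T, each an int array of shape [batch].
    embedding: float array of shape [num_symbols, embedding_size].
    Returns a list of length T, each of shape [batch, embedding_size].
    """
    # Row indexing is what a lookup into a single embedding table computes.
    return [embedding[np.asarray(ids)] for ids in encoder_inputs]


# Toy example: 5 symbols, 3-dim embeddings, 2 time steps, batch of 2.
rng = np.random.default_rng(0)
embedding = rng.standard_normal((5, 3))
encoder_inputs = [np.array([1, 4]), np.array([0, 2])]
emb_inp = embed_time_major(encoder_inputs, embedding)
print(len(emb_inp), emb_inp[0].shape)  # 2 (2, 3)
```

Each list element plays the role of one entry of `emb_inp`, which the RNN then consumes one time step at a time.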
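In the bidirectional branch, the sequence is read once forward and once in reverse, and the two final states are joined along the feature axis. The result is `[batch, 2 * state_size]`. The NumPy sketch below shows that shape logic. It is an assumption-laden stand-in: a plain tanh RNN (`h_t = tanh(x_t W_x + h_{t-1} W_h + b)`) replaces whatever `cell` the project actually uses, and all helper names are hypothetical. It gives each direction its own parameters. As far as I know, TF 0.x `bidirectional_rnn` builds the two directions in separate variable scopes, so passing the same `cell` twice does not tie their weights.

```python
import numpy as np


def simple_rnn_final_state(inputs, W_x, W_h, b):
    """Hypothetical tanh RNN over a list of [batch, in] arrays; returns the last state."""
    batch = inputs[0].shape[0]
    h = np.zeros((batch, W_h.shape[0]))
    for x in inputs:
        h = np.tanh(x @ W_x + h @ W_h + b)
    return h


def bidirectional_final_state(inputs, fw_params, bw_params):
    # Forward pass reads t = 0..T-1; backward pass reads the reversed list.
    h_fw = simple_rnn_final_state(inputs, *fw_params)
    h_bw = simple_rnn_final_state(inputs[::-1], *bw_params)
    # Join along the feature axis, like tf.concat(1, [fw, bw]) in TF < 1.0.
    return np.concatenate([h_fw, h_bw], axis=1)


rng = np.random.default_rng(0)
embedding_size, hidden = 3, 4


def make_params():
    # Separate (W_x, W_h, b) per direction.
    return (rng.standard_normal((embedding_size, hidden)),
            rng.standard_normal((hidden, hidden)),
            np.zeros(hidden))


fw_params, bw_params = make_params(), make_params()
emb_inp = [rng.standard_normal((2, embedding_size)) for _ in range(5)]
state = bidirectional_final_state(emb_inp, fw_params, bw_params)
print(state.shape)  # (2, 8)
```

With a state size of 4 per direction, the concatenated encoder state has 8 features per batch row. Whatever consumes `encoder_state` downstream must expect this doubled size when `bidirectional=True`.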