# Note: this file targets the legacy (pre-TF 1.0) APIs, where rnn.rnn /
# rnn.bidirectional_rnn take time-major lists of tensors and tf.concat
# takes the concat dimension as its first argument.
import tensorflow as tf
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import variable_scope


def embedding_encoder(encoder_inputs,
                      cell,
                      embedding,
                      num_symbols,
                      embedding_size,
                      bidirectional=False,
                      dtype=None,
                      weight_initializer=None,
                      scope=None):
  with variable_scope.variable_scope(
      scope or "embedding_encoder", dtype=dtype) as scope:
    dtype = scope.dtype
    # Encoder: create the embedding matrix if one was not passed in.
    if embedding is None:
      embedding = variable_scope.get_variable(
          "embedding", [num_symbols, embedding_size],
          initializer=weight_initializer())
    # Map each time step of symbol ids to its embedding vector.
    emb_inp = [embedding_ops.embedding_lookup(embedding, i)
               for i in encoder_inputs]
    if bidirectional:
      # Run forward and backward RNNs and concatenate their final states.
      _, output_state_fw, output_state_bw = rnn.bidirectional_rnn(
          cell, cell, emb_inp, dtype=dtype)
      encoder_state = tf.concat(1, [output_state_fw, output_state_bw])
    else:
      _, encoder_state = rnn.rnn(cell, emb_inp, dtype=dtype)
    return encoder_state
Source file: seq2seq.py (Python)
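Below is a minimal usage sketch for this function. It assumes the same legacy (pre-TF 1.0) API surface the file is written against (tf.nn.rnn_cell, time-major lists of per-step tensors); the cell size, vocabulary size, sequence length, and the tf.random_uniform_initializer choice are illustrative assumptions, not values taken from seq2seq.py.

import tensorflow as tf

num_steps, batch_size = 10, 32
cell = tf.nn.rnn_cell.GRUCell(128)
# One int32 tensor of symbol ids per time step (time-major list input).
encoder_inputs = [tf.placeholder(tf.int32, shape=[batch_size], name="enc%d" % t)
                  for t in range(num_steps)]
encoder_state = embedding_encoder(
    encoder_inputs, cell,
    embedding=None,        # let the function create the embedding matrix
    num_symbols=10000,
    embedding_size=64,
    bidirectional=True,    # final state is [batch_size, 2 * cell.state_size]
    weight_initializer=lambda: tf.random_uniform_initializer(-0.1, 0.1))

Note that weight_initializer is expected to be a zero-argument callable returning an initializer, since the function calls weight_initializer() when it builds the embedding variable itself.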