import tensorflow as tf  # snippet targets the legacy 0.x tf.nn.seq2seq API

def _load_model(self):
    """
    Creates the encoder-decoder model.
    :return: None
    """
    # Initial memory value for recurrence.
    self.prev_mem = tf.zeros((self.train_batch_size, self.memory_dim))
    # Choose the recurrent cell type (RNN/GRU/LSTM); a hedged sketch of
    # get_cell() follows this method.
    with tf.variable_scope("train_test", reuse=True):
        cell = self.get_cell()
    # Stack num_layers copies of the cell to form a multi-layer RNN that the
    # legacy seq2seq helpers share between encoder and decoder. Reusing one
    # cell object across layers was the standard pattern in this old API:
    # MultiRNNCell still creates separate variables per layer scope.
    self.cell = tf.nn.rnn_cell.MultiRNNCell([cell] * self.num_layers)
    # Embedding seq2seq model: the first scope builds the training graph;
    # the second reuses its variables with feed_previous=True so that, at
    # test time, the decoder feeds its own previous output back in.
    # NOTE: in the legacy tf.nn.seq2seq API the sixth positional argument
    # is embedding_size, so self.seq_length is used as the embedding
    # dimension here; this is worth double-checking against your intent.
    # self.enc_inp / self.dec_inp are built elsewhere; a sketch of that
    # wiring follows this snippet.
    if not self.attention:
        with tf.variable_scope("train_test"):
            self.dec_outputs, self.dec_memory = tf.nn.seq2seq.embedding_rnn_seq2seq(
                self.enc_inp, self.dec_inp, self.cell,
                self.vocab_size, self.vocab_size, self.seq_length)
        with tf.variable_scope("train_test", reuse=True):
            self.dec_outputs_tst, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
                self.enc_inp, self.dec_inp, self.cell,
                self.vocab_size, self.vocab_size, self.seq_length,
                feed_previous=True)
    else:
        with tf.variable_scope("train_test"):
            self.dec_outputs, self.dec_memory = tf.nn.seq2seq.embedding_attention_seq2seq(
                self.enc_inp, self.dec_inp, self.cell,
                self.vocab_size, self.vocab_size, self.seq_length)
        with tf.variable_scope("train_test", reuse=True):
            self.dec_outputs_tst, _ = tf.nn.seq2seq.embedding_attention_seq2seq(
                self.enc_inp, self.dec_inp, self.cell,
                self.vocab_size, self.vocab_size, self.seq_length,
                feed_previous=True)
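
# Hedged sketch (not from the original post): get_cell() is called above but
# not defined in this snippet. Assuming a self.cell_type attribute taking
# 'rnn' | 'gru' | 'lstm' and self.memory_dim as the hidden-state size, a
# minimal implementation against the same legacy tf.nn.rnn_cell API could be:
def get_cell(self):
    """Hypothetical helper: build a single recurrent cell of the chosen type."""
    if self.cell_type == 'lstm':
        return tf.nn.rnn_cell.BasicLSTMCell(self.memory_dim)
    elif self.cell_type == 'gru':
        return tf.nn.rnn_cell.GRUCell(self.memory_dim)
    return tf.nn.rnn_cell.BasicRNNCell(self.memory_dim)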
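
# Hedged sketch (not from the original post): self.enc_inp and self.dec_inp
# are consumed above but never built in this snippet. With the legacy
# tf.nn.seq2seq API they are plain Python lists of int32 placeholders, one
# per time step; labels and weights below are illustrative names only. This
# helper would run before _load_model().
def _init_placeholders(self):
    """Hypothetical helper: wire up step-wise seq2seq inputs."""
    self.enc_inp = [tf.placeholder(tf.int32, shape=[None], name="enc%i" % t)
                    for t in range(self.seq_length)]
    self.labels = [tf.placeholder(tf.int32, shape=[None], name="labels%i" % t)
                   for t in range(self.seq_length)]
    # Uniform per-step loss weights; real code often masks padding instead.
    self.weights = [tf.ones_like(l, dtype=tf.float32) for l in self.labels]
    # Decoder input: a GO token (assumed id 0) followed by the shifted labels.
    self.dec_inp = ([tf.zeros_like(self.enc_inp[0], name="GO")]
                    + self.labels[:-1])
    # After _load_model() has produced self.dec_outputs, a training loss can
    # be attached with the legacy helper:
    #   self.loss = tf.nn.seq2seq.sequence_loss(self.dec_outputs,
    #                                           self.labels, self.weights)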