def _set_train_model(self):
    """Build the encoder/decoder RNN cells, the graph, and the saver.

    Creates one multi-layer cell for the encoder and one for the decoder
    (LSTM cells when ``use_lstm`` is truthy, GRU cells otherwise), stores
    them on ``self.encoder_cell`` and ``self.decoder_cell``, then calls
    ``self._make_graph(forward_only)`` and finally stores a
    ``tf.train.Saver`` covering all global variables on ``self.saver``.

    :return: None
    """
    # NOTE(review): ``use_lstm`` and ``forward_only`` are neither parameters
    # nor locals of this method; presumably they are module-level names
    # defined elsewhere in the file -- confirm against the full source.
    cell_cls = LSTMCell if use_lstm else GRUCell

    # Build a *separate* cell object for every layer. Repeating one cell
    # instance inside MultiRNNCell makes all layers refer to the same
    # object, which TensorFlow 1.x either rejects (variable-scope reuse
    # error) or treats as weight sharing between layers.
    enc_layers = [
        cell_cls(self.enc_hidden_size) for _ in range(self.enc_num_layers)
    ]
    dec_layers = [
        cell_cls(self.dec_hidden_size) for _ in range(self.dec_num_layers)
    ]

    self.encoder_cell = MultiRNNCell(enc_layers)
    self.decoder_cell = MultiRNNCell(dec_layers)

    self._make_graph(forward_only)

    # The saver is created after the graph so it sees every variable
    # that _make_graph added to the global collection.
    self.saver = tf.train.Saver(tf.global_variables())
neuralnet_node_attnseq2seq.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录