def _make_rnn_cell(self, i):
    """Build one dropout-wrapped recurrent cell of the configured type.

    The cell class is chosen from ``self._cell_type``:
    "lstm" -> ``tf.contrib.rnn.LSTMCell``, "gru" -> ``tf.contrib.rnn.GRUCell``,
    "basic-tanh" -> ``tf.contrib.rnn.BasicRNNCell``. The chosen class is
    constructed with ``self.output_size`` as its only argument.

    Args:
        i: Index of the cell. It is used only to derive the dropout seed
            (``8 + 33 * i``), so each index gets a different seed.

    Returns:
        A ``tf.contrib.rnn.DropoutWrapper`` around the constructed cell.

    Raises:
        ValueError: If ``self._cell_type`` matches none of the supported names.
    """
    # Lambdas keep construction lazy: ``tf`` is only looked up for the
    # selected type, and never when the type is invalid.
    factories = (
        ("lstm", lambda units: tf.contrib.rnn.LSTMCell(units)),
        ("gru", lambda units: tf.contrib.rnn.GRUCell(units)),
        ("basic-tanh", lambda units: tf.contrib.rnn.BasicRNNCell(units)),
    )
    cell_type = self._cell_type
    # First match wins, mirroring an if/elif chain on ``cell_type == name``.
    factory = next(
        (make for name, make in factories if cell_type == name), None
    )
    if factory is None:
        raise ValueError("Invalid RNN Cell type")

    base_cell = factory(self.output_size)
    # NOTE(review): self._dropout is passed as the *keep* probability, not a
    # drop rate — confirm the attribute really stores a keep probability.
    return tf.contrib.rnn.DropoutWrapper(
        base_cell, output_keep_prob=self._dropout, seed=8 + 33 * i
    )
# Source file: encoders.py
# Language: python
# Views: 26
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents