def _create_rnn_cell(self):
    """
    Build one dropout-wrapped RNN cell matching this RNN's configuration.

    The cell class is chosen from ``self.cell_type`` (GRU or LSTM) and is
    constructed with ``self.num_units`` units. It is always wrapped in a
    ``DropoutWrapper``, even when dropout is disabled.

    Returns
    -------
    rnn cell
        A ``DropoutWrapper`` around a freshly created GRU or LSTM cell. The
        same keep probability is passed as both of the wrapper's positional
        keep-probability arguments. When ``self.keep_prob`` is None, 1.0 is
        used, which means no units are dropped.

    Raises
    ------
    ValueError
        If ``self.cell_type`` is neither ``CellType.GRU`` nor
        ``CellType.LSTM``.
    """
    # Treat "no keep probability configured" as "keep everything".
    keep_prob = self.keep_prob if self.keep_prob is not None else 1.0

    # Choose the concrete cell class; anything unrecognised is rejected.
    if self.cell_type == CellType.GRU:
        cell_factory = GRUCell
    elif self.cell_type == CellType.LSTM:
        cell_factory = LSTMCell
    else:
        raise ValueError("unknown cell type: {}".format(self.cell_type))

    base_cell = cell_factory(self.num_units)
    # NOTE(review): assumes TensorFlow's DropoutWrapper(cell, input_keep_prob,
    # output_keep_prob, ...) positional order — confirm against the import.
    return DropoutWrapper(base_cell, keep_prob, keep_prob)
评论列表
文章目录