def _get_rnn_unit(self, rnn_unit):
if rnn_unit == 'lstm':
fw_cell = rnn.BasicLSTMCell(self._nb_hidden, forget_bias=1., state_is_tuple=True)
bw_cell = rnn.BasicLSTMCell(self._nb_hidden, forget_bias=1., state_is_tuple=True)
elif rnn_unit == 'gru':
fw_cell = rnn.GRUCell(self._nb_hidden)
bw_cell = rnn.GRUCell(self._nb_hidden)
else:
raise ValueError('rnn_unit must in (lstm, gru)!')
return fw_cell, bw_cell
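# Scraped web-page navigation text (not code), commented out so the module parses:
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")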