def __init__(self, n_hidden, cell="GRU"):
    """
    qa_rnn module init: build the recurrent cell stored in ``self.rnn_cell``.

    :param n_hidden: num of hidden units, passed as ``num_units`` to the cell
    :param cell: gru|lstm|basic_rnn, matched case-insensitively
                 (so the default "GRU" and "gru" select the same cell)
    :raises ValueError: if ``cell`` is not one of the supported names
    """
    # Normalise the case so the lowercase names listed above and the
    # upper-case default both match. Values without ``lower`` (non-strings)
    # are left as they are and end up in the error branch below.
    kind = cell.lower() if hasattr(cell, "lower") else cell

    # Only the requested cell is constructed; nothing is built and then
    # thrown away.
    if kind == "gru":
        self.rnn_cell = rnn.GRUCell(num_units=n_hidden)
    elif kind == "lstm":
        self.rnn_cell = rnn.LSTMCell(num_units=n_hidden)
    elif kind == "basic_rnn":
        self.rnn_cell = rnn.BasicRNNCell(num_units=n_hidden)
    else:
        # ValueError is a subclass of Exception, so handlers that catch
        # Exception keep working; str() keeps the message buildable for
        # non-string arguments.
        raise ValueError(str(cell) + " not supported.")
# Comment list
# Article table of contents