def __init__(self, batch_size, num_tokens, embed_size, word_gru_hidden,
             bidirectional=True, init_range=0.1, use_lstm=False):
    super(AttentionWordRNN, self).__init__()

    self.batch_size = batch_size
    self.num_tokens = num_tokens
    self.embed_size = embed_size
    self.word_gru_hidden = word_gru_hidden
    self.bidirectional = bidirectional
    self.use_lstm = use_lstm

    # Embedding lookup table: maps token ids to dense word vectors.
    self.lookup = nn.Embedding(num_tokens, embed_size)

    if bidirectional:
        # A bidirectional encoder doubles the feature size seen by the attention layer.
        if use_lstm:
            print("inside using LSTM")
            self.word_gru = nn.LSTM(embed_size, word_gru_hidden, bidirectional=True)
        else:
            self.word_gru = nn.GRU(embed_size, word_gru_hidden, bidirectional=True)
        # Word-level attention parameters: linear transform (W, b) plus a projection
        # onto the word context vector.
        self.weight_W_word = nn.Parameter(torch.Tensor(2 * word_gru_hidden, 2 * word_gru_hidden))
        self.bias_word = nn.Parameter(torch.Tensor(2 * word_gru_hidden, 1))
        self.weight_proj_word = nn.Parameter(torch.Tensor(2 * word_gru_hidden, 1))
    else:
        if use_lstm:
            self.word_gru = nn.LSTM(embed_size, word_gru_hidden, bidirectional=False)
        else:
            self.word_gru = nn.GRU(embed_size, word_gru_hidden, bidirectional=False)
        self.weight_W_word = nn.Parameter(torch.Tensor(word_gru_hidden, word_gru_hidden))
        self.bias_word = nn.Parameter(torch.Tensor(word_gru_hidden, 1))
        self.weight_proj_word = nn.Parameter(torch.Tensor(word_gru_hidden, 1))

    # Normalizes the attention scores in the forward pass; note that recent
    # PyTorch versions expect an explicit dim argument for nn.Softmax.
    self.softmax_word = nn.Softmax()

    # nn.Parameter(torch.Tensor(...)) is uninitialized memory, so all attention
    # parameters (including the bias) are given a uniform initialization here.
    self.weight_W_word.data.uniform_(-init_range, init_range)
    self.weight_proj_word.data.uniform_(-init_range, init_range)
    self.bias_word.data.uniform_(-init_range, init_range)
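For reference, a minimal instantiation sketch is shown below. The hyperparameter values and the dummy token batch are made up purely to illustrate the resulting parameter and tensor shapes, and it assumes torch, torch.nn, and the AttentionWordRNN class above are already in scope.

import torch

# Hypothetical hyperparameters, chosen only to illustrate the shapes.
model = AttentionWordRNN(batch_size=32, num_tokens=10000, embed_size=100,
                         word_gru_hidden=50, bidirectional=True)

# With bidirectional=True the attention parameters act on 2 * word_gru_hidden = 100 features.
print(model.weight_W_word.shape)     # torch.Size([100, 100])
print(model.weight_proj_word.shape)  # torch.Size([100, 1])

# The embedding lookup maps a (seq_len, batch) tensor of token ids
# to a (seq_len, batch, embed_size) tensor of word vectors.
tokens = torch.randint(0, 10000, (20, 32))
embedded = model.lookup(tokens)
print(embedded.shape)                # torch.Size([20, 32, 100])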