def init_hidden(self):
    """Return a zero-filled initial hidden state for the recurrent layer.

    Every tensor has shape ``(num_directions, self.batch_size,
    self.word_gru_hidden)``, where ``num_directions`` is 2 when
    ``self.bidirectional`` is truthy and 1 otherwise.

    Returns:
        If ``self.use_lstm`` is truthy, a two-element list of independent
        zero tensors (presumably consumed as ``(h_0, c_0)`` by an LSTM —
        confirm against ``forward``); otherwise a single zero tensor.
    """
    num_directions = 2 if self.bidirectional else 1
    # NOTE(review): the leading dimension counts directions only, so this
    # assumes a single-layer RNN (PyTorch expects num_directions * num_layers).
    shape = (num_directions, self.batch_size, self.word_gru_hidden)

    # NOTE(review): torch.zeros uses the default device; if the model runs on
    # a GPU the caller presumably moves these tensors — confirm.
    # Variable is kept for compatibility with old PyTorch releases; on
    # modern versions Variable(t) just returns the tensor.
    if self.use_lstm:
        # Two separate allocations so the two states never alias each other.
        return [Variable(torch.zeros(*shape)), Variable(torch.zeros(*shape))]
    return Variable(torch.zeros(*shape))
评论列表
文章目录