def init_hidden(self, height, width):
    """Reset the LSTM hidden and cell states to zeros for a ``height`` x ``width`` input.

    The batch dimension of both states is ``height * width``. Nothing is
    returned; the results are stored on the instance.

    Args:
        height: First input dimension; stored as ``self.height``.
        width: Second input dimension; stored as ``self.width``.

    Side effects:
        Sets ``self.height``, ``self.width``, ``self.batch``,
        ``self.cell_state`` and ``self.hidden_state``. Both states are
        zero-filled tensors of shape
        ``(self.lstm_layer, self.batch, self.hidden_dim)`` that do not require
        gradients. They are placed on the current CUDA device when
        ``self.on_gpu`` is truthy, otherwise on the CPU.
    """
    self.height = height
    self.width = width
    self.batch = height * width
    # torch.autograd.Variable is deprecated (PyTorch >= 0.4): Variable(t) just
    # returns t, and plain tensors track autograd state themselves. Allocating
    # directly on the target device avoids a CPU allocation followed by a
    # host-to-device copy via .cuda().
    device = torch.device("cuda") if self.on_gpu else torch.device("cpu")
    shape = (self.lstm_layer, self.batch, self.hidden_dim)
    self.cell_state = torch.zeros(shape, device=device)
    self.hidden_state = torch.zeros(shape, device=device)
评论列表
文章目录