def forward(self, x):
    # x: (batch, seq_len), one scalar feature per time step
    outputs = []
    # Initial hidden and cell states, created on the same device/dtype as x
    # (replaces the deprecated Variable(...).cuda() idiom)
    h_t = x.new_zeros(x.size(0), self.hidden_size)
    c_t = x.new_zeros(x.size(0), self.hidden_size)
    # Step through the sequence one time step at a time
    for input_t in x.chunk(x.size(1), dim=1):
        # each chunk is (batch, 1): a single scalar input per example
        input_t = input_t.contiguous().view(input_t.size(0), 1)
        h_t, c_t = self.lstm1(input_t, (h_t, c_t))
        # collects the cell state c_t; the hidden state h_t is the more common choice
        outputs.append(c_t)
    # (batch, seq_len, hidden_size); squeeze(2) only has an effect when hidden_size == 1
    outputs = torch.stack(outputs, 1).squeeze(2)
    batch_size, seq_len = outputs.size(0), outputs.size(1)
    # run the classifier over every time step at once
    out = outputs.contiguous().view(batch_size * seq_len, self.hidden_size)
    out = self.fc(out)
    out = out.view(batch_size, seq_len, self.num_classes)
    return out
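
For context, here is a minimal sketch of the module this forward method could belong to. The original post does not show __init__, so the class name SequenceTagger and the layer constructions below are assumptions inferred from the attributes the snippet uses (lstm1, fc, hidden_size, num_classes):

import torch
import torch.nn as nn

class SequenceTagger(nn.Module):  # hypothetical name, not from the post
    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        # one scalar feature per step, matching the (batch, 1) view in forward()
        self.lstm1 = nn.LSTMCell(1, hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    # forward() as defined above goes here

A quick shape check with dummy data, assuming forward() has been pasted into the class:

model = SequenceTagger(hidden_size=64, num_classes=10)
x = torch.randn(8, 20)   # 8 sequences, 20 time steps each
out = model(x)           # -> torch.Size([8, 20, 10]): per-step class scores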