def make_target(self, batch, lengths, n_mixtures=None):
    """Build masked training targets from a batch of 5-feature steps.

    An end-of-sequence row ``[0, 0, 0, 0, 1]`` is appended to ``batch`` along
    dim 0, then the result is split into per-feature targets:

    * features 0 and 1 (``dx``/``dy``) are each repeated ``n_mixtures`` times
      along a new dim 2;
    * features 2, 3 and 4 (``p``) are kept together as the last dim.

    Args:
        batch: Variable/tensor indexed as ``batch[step, sample, feature]`` with
            5 features per step. Presumably (seq_len, batch_size, 5) in
            stroke-5 format -- TODO confirm against the data loader.
        lengths: one length per sample (column of ``batch``); column ``i`` of
            the mask is set to 1 for its first ``lengths[i]`` rows.
        n_mixtures: how many copies of dx/dy to stack along dim 2. Defaults to
            the module-level ``hp.M`` so existing callers are unaffected.

    Returns:
        Tuple ``(mask, dx, dy, p)`` with L = batch.size(0) + 1 and
        B = batch.size(1): ``mask`` is a float (L, B) tensor of 0/1,
        ``dx`` and ``dy`` are (L, B, n_mixtures), ``p`` is (L, B, 3).
        All four are detached from the autograd graph and live on the same
        device as ``batch``.
    """
    if n_mixtures is None:
        n_mixtures = hp.M

    data = batch.data
    seq_len, batch_size = data.size(0), data.size(1)
    # Follow the batch's own device rather than a global flag, so the
    # concatenation below can never mix CPU and GPU tensors.
    on_gpu = data.is_cuda

    # One [0, 0, 0, 0, 1] row per sample, shaped (1, B, 5) for concatenation.
    eos = torch.Tensor([0, 0, 0, 0, 1]).repeat(batch_size, 1).unsqueeze(0)

    # Size the mask from the real sequence length (+1 for the EOS row) instead
    # of a global maximum, so it always matches the targets built below.
    mask = torch.zeros(seq_len + 1, batch_size)
    for col, length in enumerate(lengths):
        mask[:length, col] = 1

    if on_gpu:
        eos = eos.cuda()
        mask = mask.cuda()

    target = torch.cat([data, eos], 0)

    mask = Variable(mask).detach()
    dx = Variable(torch.stack([target[:, :, 0]] * n_mixtures, 2)).detach()
    dy = Variable(torch.stack([target[:, :, 1]] * n_mixtures, 2)).detach()
    # .contiguous() gives p its own compact storage, as the original stack did.
    p = Variable(target[:, :, 2:5].contiguous()).detach()
    return mask, dx, dy, p
评论列表
文章目录