sketch_rnn.py source code

python

Project: Pytorch-Sketch-RNN    Author: alexis-jacq
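```python
import torch

# Method of the Model class in sketch_rnn.py; `hp` is the module-level
# hyperparameter object, and hp.M is the number of mixture components.
# (The original Variable wrappers are unnecessary since PyTorch 0.4.)
def make_target(self, batch, lengths):
    # batch: (Nmax, batch_size, 5) padded strokes in stroke-5 format
    # (dx, dy, p1, p2, p3); lengths: true length of each sequence.
    device = batch.device
    # Append one end-of-sketch token (0, 0, 0, 0, 1) to every sequence.
    eos = torch.tensor([0., 0., 0., 0., 1.], dtype=batch.dtype, device=device)
    batch = torch.cat([batch, eos.expand(1, batch.size(1), 5)], 0)
    # mask[t, i] = 1 on the valid steps of sequence i, 0 on padding.
    # batch.size(0) == Nmax + 1 after the concatenation above.
    mask = torch.zeros(batch.size(0), batch.size(1), device=device)
    for indice, length in enumerate(lengths):
        mask[:length, indice] = 1
    # Targets are constants for the loss, so cut them out of the graph.
    batch = batch.detach()
    # dx and dy are copied once per mixture component: (Nmax+1, batch_size, M).
    dx = torch.stack([batch[:, :, 0]] * hp.M, 2)
    dy = torch.stack([batch[:, :, 1]] * hp.M, 2)
    # Pen-state targets (p1, p2, p3): (Nmax+1, batch_size, 3).
    p = batch[:, :, 2:5]
    return mask, dx, dy, p
```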
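To check the shapes this method produces, here is a minimal standalone sketch of the same target construction. It assumes a padded float batch of shape (Nmax, batch_size, 5) in stroke-5 format and passes the number of mixture components `M` explicitly instead of reading the module-level `hp.M`. The function name `make_target_standalone` is hypothetical, and the mask is built with a vectorized comparison rather than the original per-sequence loop; both should yield the same mask.

```python
import torch


def make_target_standalone(batch, lengths, M):
    """Hypothetical standalone variant of make_target (not from the original repo).

    batch:   (Nmax, batch_size, 5) float tensor of stroke-5 points
    lengths: true (unpadded) length of each sequence, one per column of `batch`
    M:       number of Gaussian mixture components (hp.M in sketch_rnn.py)
    """
    n_max, batch_size, _ = batch.shape
    # One end-of-sketch token (0, 0, 0, 0, 1) appended to every sequence.
    eos = torch.tensor([0., 0., 0., 0., 1.], dtype=batch.dtype, device=batch.device)
    batch = torch.cat([batch, eos.expand(1, batch_size, 5)], 0).detach()
    # Vectorized mask: mask[t, i] = 1 if t < lengths[i], else 0.
    steps = torch.arange(n_max + 1, device=batch.device).unsqueeze(1)  # (Nmax+1, 1)
    lens = torch.as_tensor(lengths, device=batch.device).unsqueeze(0)  # (1, batch_size)
    mask = (steps < lens).to(batch.dtype)                               # (Nmax+1, batch_size)
    # dx and dy repeated once per mixture component; pen states kept as-is.
    dx = batch[:, :, 0].unsqueeze(2).repeat(1, 1, M)                    # (Nmax+1, batch_size, M)
    dy = batch[:, :, 1].unsqueeze(2).repeat(1, 1, M)
    p = batch[:, :, 2:5]                                                # (Nmax+1, batch_size, 3)
    return mask, dx, dy, p


# Toy batch: 2 sketches padded to Nmax = 4 steps, with true lengths 3 and 2.
toy = torch.zeros(4, 2, 5)
toy[:, :, 0] = torch.arange(8.).view(4, 2)  # distinct dx values, easy to check
mask, dx, dy, p = make_target_standalone(toy, [3, 2], M=20)
print(mask.shape, dx.shape, p.shape)
# torch.Size([5, 2]) torch.Size([5, 2, 20]) torch.Size([5, 2, 3])
```

The comparison `steps < lens` broadcasts a column of time indices against a row of lengths, so the mask is built in one tensor operation instead of a Python loop over the batch.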