draw_model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:draw_pytorch 作者: chenzhaomin123 项目源码 文件源码
def attn_window(self,h_dec):
        params = self.dec_linear(h_dec)
        gx_,gy_,log_sigma_2,log_delta,log_gamma = params.split(1,1)  #21

        # gx_ = Variable(torch.ones(4,1))
        # gy_ = Variable(torch.ones(4, 1) * 2)
        # log_sigma_2 = Variable(torch.ones(4, 1) * 3)
        # log_delta = Variable(torch.ones(4, 1) * 4)
        # log_gamma = Variable(torch.ones(4, 1) * 5)

        gx = (self.A + 1) / 2 * (gx_ + 1)    # 22
        gy = (self.B + 1) / 2 * (gy_ + 1)    # 23
        delta = (max(self.A,self.B) - 1) / (self.N - 1) * torch.exp(log_delta)  # 24
        sigma2 = torch.exp(log_sigma_2)
        gamma = torch.exp(log_gamma)

        return self.filterbank(gx,gy,sigma2,delta),gamma
    # correct
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号