vaelm.py source code

Project: vaelm  Author: TatsuyaShirakawa
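def __call__(self, w, train=True, dpratio=0.5):
        # One time step: embed the word IDs w, pass them through the stacked
        # recurrent (LSTM-style) layers, and apply self.hy to the top hidden state.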
        x = self.embed(w)
        self.maybe_init_state(len(x.data), x.data.dtype)

        for i in range(self.num_layers):

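            if self.ignore_label is not None:
                # Mask of non-zero input entries; at the first layer this assumes
                # ignored (padding) labels are embedded as all-zero vectors.
                enable = (x.data != 0)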

            c = F.dropout(self.get_c(i), train=train, ratio=dpratio)
            h = F.dropout(self.get_h(i), train=train, ratio=dpratio)
            x = F.dropout(x, train=train, ratio=dpratio)
            c, h = self.get_l(i)(c, h, x)

            if self.ignore_label is not None:
                self.set_c(i, F.where(enable, c, self.get_c(i)))
                self.set_h(i, F.where(enable, h, self.get_h(i)))
            else:
                self.set_c(i, c)
                self.set_h(i, h)

            x = self.get_h(i)

        x = F.dropout(x, train=train, ratio=dpratio)
        return self.hy(x)
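
The masked state update in the method above (the `F.where(enable, c, self.get_c(i))` lines) can be illustrated with a minimal, dependency-free sketch. It is not the project's code: `masked_update` and the sample values are hypothetical, and plain nested lists stand in for Chainer variables. The point is only to show that positions where `enable` is false keep their previous state, so padded sequences are not advanced.

```python
def masked_update(enable, new_state, old_state):
    """Element-wise select: take new_state where enable is True, else old_state."""
    return [
        [n if e else o for e, n, o in zip(e_row, n_row, o_row)]
        for e_row, n_row, o_row in zip(enable, new_state, old_state)
    ]


# Hypothetical batch of two sequences with three hidden units each.
# The second sequence is padding at this time step, so its mask is all False.
enable = [[True, True, True], [False, False, False]]
old_h = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
new_h = [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]

updated_h = masked_update(enable, new_h, old_h)
print(updated_h)  # [[0.7, 0.8, 0.9], [0.4, 0.5, 0.6]]
```

The first row takes the freshly computed state, while the padded second row is left unchanged, which is the same selection the method applies to both `c` and `h` at every layer.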