net.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:NlpUtil 作者: trtd56 项目源码 文件源码
def __call__(self, xs, train=True):
        batch = len(xs)
        if self.hx is None:
            xp = self.xp
            self.hx = Variable(
                xp.zeros((self.n_layers, batch, self.state_size), dtype=xs[0].dtype),
                volatile='auto')
        if self.cx is None:
            xp = self.xp
            self.cx = Variable(
                xp.zeros((self.n_layers, batch, self.state_size), dtype=xs[0].dtype),
                volatile='auto')
        hy, cy, ys = super(NStepLSTM, self).__call__(self.hx, self.cx, xs, train)
        self.hx, self.cx = hy, cy
        return ys
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号