LSTMEncDecAttn.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:mlpnlp-nmt 作者: mlpnlp 项目源码 文件源码
def init_hx(self, xs):
        hx_shape = self.n_layers * self.direction
        with cuda.get_device_from_id(self._device_id):
            if args.chainer_version_check[0] == 2:
                hx = chainer.Variable(
                    self.xp.zeros((hx_shape, xs.data.shape[1], self.out_size),
                                  dtype=xs.dtype))
            else:
                hx = chainer.Variable(
                    self.xp.zeros((hx_shape, xs.data.shape[1], self.out_size),
                                  dtype=xs.dtype),
                    volatile='auto')
        return hx
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号