lstm_decoder.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:DSTC6-End-to-End-Conversation-Modeling 作者: dialogtekgeek 项目源码 文件源码
def update(self, s, i):
        """Update decoder state

        Args:
            s (any): Current (hidden, cell) states.  If ``None`` is specified 
                     zero-vector is used.
            i (int): input label.
        Return:
            (~chainer.Variable) updated decoder state
        """
        if cuda.get_device_from_array(s[0].data).id >= 0:
            xp = cuda.cupy
        else:
            xp = np

        v = chainer.Variable(xp.array([i],dtype=np.int32))
        x = self.embed(v)
        if s is not None:
            hy, cy, dy = self.lstm(s[0], s[1], [x])
        else:
            hy, cy, dy = self.lstm(None, None, [x])

        return hy, cy, dy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号