encoder_decoder.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:workspace 作者: nojima 项目源码 文件源码
def translate(self, sentence: np.ndarray, max_length: int = 30) -> List[int]:
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            sentence = sentence[::-1]

            embedded_xs = self._embed_input(sentence)
            hidden_states, cell_states, attentions = self._encoder(None, None, [embedded_xs])

            wid = EOS
            result = []

            for i in range(max_length):
                output, hidden_states, cell_states = \
                    self._translate_one_word(wid, hidden_states, cell_states, attentions)

                wid = np.argmax(output.data)
                if wid == EOS:
                    break
                result.append(wid)

            return result
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号