net.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:convolutional_seq2seq 作者: soskek 项目源码 文件源码
def translate(self, x_block, max_length=50):
    """Greedily decode one output sequence per source sequence.

    At every step the model is called on the source batch and on the
    tokens generated so far, and the highest-scoring token (argmax) is
    appended to each sequence. Token id 0 is treated as end-of-sequence:
    decoding stops early once every sequence in the batch has emitted
    it, and otherwise stops after ``max_length`` steps.

    Args:
        x_block: Source batch. It is passed through
            ``source_pad_concat_convert`` and must come back as a 2-D
            array (batch, source length).
            NOTE(review): presumably a list of per-sentence id arrays;
            confirm against the caller.
        max_length: Maximum number of decoding steps.

    Returns:
        A list with one 1-D NumPy int array per batch row, holding the
        generated token ids up to (not including) the first id 0. A row
        that would otherwise be empty is returned as ``[1]``.
        NOTE(review): id 1 looks like an "unknown"/placeholder token;
        confirm against the vocabulary.
    """
    # TODO: efficient inference by re-using convolution result
    with chainer.no_backprop_mode(), chainer.using_config('train', False):
        x_block = source_pad_concat_convert(x_block, device=None)
        # Unpacking two values also enforces that x_block is 2-D.
        batch, _ = x_block.shape
        # Every target sequence starts with a single id-0 token.
        y_block = self.xp.zeros((batch, 1), dtype=x_block.dtype)
        # Per-row count of emitted id-0 tokens; non-zero means finished.
        eos_flags = self.xp.zeros((batch, ), dtype=x_block.dtype)
        result = []
        for _ in range(max_length):
            # The generated prefix is passed twice, as in the original
            # call; only the prediction for the last position is used.
            log_prob_tail = self(x_block, y_block, y_block,
                                 get_prediction=True)
            ys = self.xp.argmax(log_prob_tail.data, axis=1).astype('i')
            result.append(ys)
            y_block = F.concat([y_block, ys[:, None]], axis=1).data
            eos_flags += (ys == 0)
            if self.xp.all(eos_flags):
                break

    if not result:
        # max_length <= 0: no step was run, and stacking an empty list
        # would raise. Return the same placeholder used for empty outputs.
        return [np.array([1], 'i') for _ in range(batch)]

    # (steps, batch) -> (batch, steps), moved to host memory.
    result = cuda.to_cpu(self.xp.stack(result).T)

    # Remove EOS tags and everything generated after them.
    outs = []
    for y in result:
        eos_positions = np.flatnonzero(y == 0)
        if eos_positions.size > 0:
            y = y[:eos_positions[0]]
        if y.size == 0:
            # Never return an empty sequence.
            y = np.array([1], 'i')
        outs.append(y)
    return outs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号