encoder_decoder.py source code

python

Project: workspace  Author: nojima
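from typing import List

import numpy as np
import chainer.functions as F
from chainer import Variable

# Method excerpt; the enclosing class is not shown.
# EOS (the end-of-sequence token id) is defined elsewhere in the original module.


def __call__(self, xs: List[Variable], ys: List[Variable]) -> Variable:
    batch_size = len(xs)
    # Reverse each source sequence before encoding.
    xs = [x[::-1] for x in xs]

    # Decoder input starts with EOS; decoder target ends with EOS.
    eos = np.array([EOS], dtype=np.int32)
    ys_in = [F.concat((eos, y), axis=0) for y in ys]
    ys_out = [F.concat((y, eos), axis=0) for y in ys]

    embedded_xs = [self._embed_input(x) for x in xs]
    embedded_ys = [self._embed_output(y) for y in ys_in]

    # Encode with no initial hidden/cell state, then decode from the encoder's final states.
    hidden_states, cell_states, attentions = self._encoder(None, None, embedded_xs)
    _, _, embedded_outputs = self._decoder(hidden_states, cell_states, embedded_ys)

    # Accumulate the cross-entropy of each sequence, then divide by the batch size.
    loss = 0
    for embedded_output, y, attention in zip(embedded_outputs, ys_out, attentions):
        if self._use_attention:
            output = self._calculate_attention_layer_output(embedded_output, attention)
        else:
            output = self._extract_output(embedded_output)
        loss += F.softmax_cross_entropy(output, y)
    loss /= batch_size

    return loss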
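
The data preparation at the top of `__call__` (reversing the source, prepending EOS to the decoder input, appending EOS to the decoder target) can be tried in isolation. The following is a minimal sketch using plain Python lists in place of Chainer variables; `prepare_pair` is a hypothetical helper that is not part of the original file, and `EOS = 0` is an assumed value, since the original module defines its own constant.

```python
from typing import List, Tuple

EOS = 0  # assumed end-of-sequence id; the original module defines its own EOS


def prepare_pair(x: List[int], y: List[int]) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper mirroring the list handling in __call__ with plain lists."""
    x_reversed = x[::-1]  # source tokens are reversed before encoding
    y_in = [EOS] + y      # decoder input: EOS comes first
    y_out = y + [EOS]     # decoder target: same tokens shifted by one, ending in EOS
    return x_reversed, y_in, y_out


x_rev, y_in, y_out = prepare_pair([5, 6, 7], [8, 9])
print(x_rev)  # [7, 6, 5]
print(y_in)   # [0, 8, 9]
print(y_out)  # [8, 9, 0]
```

Because `y_in` and `y_out` have the same length and are offset by one position, each decoder step is trained to predict the token that follows its input, with the final step predicting EOS.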