hnmt.py source code


Project: hnmt  Author: robertostling
```python
import theano
import theano.tensor as T

# Method of the hnmt model class (hence `self`); batch_sequence_crossentropy()
# is a helper defined elsewhere in the hnmt project.

def xent(self, inputs, inputs_mask, chars, chars_mask,
         outputs, outputs_mask, attention):
    # Run the model to get predicted output distributions and attention.
    pred_outputs, pred_attention = self(
            inputs, inputs_mask, chars, chars_mask, outputs, outputs_mask)
    # Output cross-entropy against outputs[1:], i.e. every output position
    # except the first.
    outputs_xent = batch_sequence_crossentropy(
            pred_outputs, outputs[1:], outputs_mask[1:])
    # Note that pred_attention will contain zero elements for masked-out
    # positions. To avoid taking log(0), we add 1 at masked-out positions
    # (those terms are removed anyway by the multiplication with
    # attention_mask).
    batch_size = attention.shape[1].astype(theano.config.floatX)
    attention_mask = (inputs_mask.dimshuffle('x', 1, 0) *
                      outputs_mask[1:].dimshuffle(0, 1, 'x')
                      ).astype(theano.config.floatX)
    epsilon = 1e-6
    attention_xent = (
            -attention[1:]
            * T.log(epsilon + pred_attention + (1-attention_mask))
            * attention_mask).sum() / batch_size
    return outputs_xent, attention_xent
```
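Written out, the attention term of `xent` is a masked cross-entropy between the reference attention and the predicted attention, divided by the batch size. The notation here is chosen for illustration: $A$ is `attention` with shape $(L, B, S)$ (output steps, batch, input positions), output steps are indexed from 0 so that `attention[1:]` covers $t = 1, \dots, L-1$, $\hat{A}_t$ stands for `pred_attention[t-1]` (the array the code multiplies elementwise with `attention[1:]`), and $m^{\text{in}}$, $m^{\text{out}}$ are `inputs_mask` and `outputs_mask`:

$$
M_{t,b,s} = m^{\text{out}}_{t,b}\, m^{\text{in}}_{s,b},\qquad
\mathcal{L}_{\text{att}} = -\frac{1}{B}\sum_{t=1}^{L-1}\sum_{b=1}^{B}\sum_{s=1}^{S} M_{t,b,s}\, A_{t,b,s}\,\log\!\bigl(\epsilon + \hat{A}_{t,b,s} + 1 - M_{t,b,s}\bigr)
$$

Where $M_{t,b,s} = 1$ the log argument is $\epsilon + \hat{A}_{t,b,s}$; where $M_{t,b,s} = 0$ it is at least $1 + \epsilon$ (attention weights are non-negative), so the logarithm stays finite and the term is then zeroed by the factor $M_{t,b,s}$.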
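To check the masking logic on small arrays without Theano, the attention term can be restated in NumPy. This is a sketch, not part of hnmt: the function name `attention_xent` and the array layouts in its docstring are assumptions inferred from the `dimshuffle` calls in the original method.

```python
import numpy as np


def attention_xent(attention, pred_attention, inputs_mask, outputs_mask,
                   epsilon=1e-6):
    """Masked attention cross-entropy, mirroring the Theano expression.

    attention:      (n_out, batch, n_in) reference attention; row 0 is skipped
    pred_attention: (n_out - 1, batch, n_in) predicted attention, aligned
                    with attention[1:] (assumed layout)
    inputs_mask:    (n_in, batch), 1.0 for real input positions, 0.0 for padding
    outputs_mask:   (n_out, batch), 1.0 for real output positions, 0.0 for padding
    """
    batch_size = float(attention.shape[1])
    # inputs_mask.dimshuffle('x', 1, 0)      -> (1, batch, n_in)
    # outputs_mask[1:].dimshuffle(0, 1, 'x') -> (n_out - 1, batch, 1)
    attention_mask = (inputs_mask.T[np.newaxis, :, :]
                      * outputs_mask[1:, :, np.newaxis])
    # Masked positions get +1 inside the log so it never sees 0 there;
    # their terms are then zeroed by the multiplication with attention_mask.
    log_pred = np.log(epsilon + pred_attention + (1.0 - attention_mask))
    return float((-attention[1:] * log_pred * attention_mask).sum() / batch_size)


if __name__ == "__main__":
    n_out, batch, n_in = 3, 2, 2
    attention = np.zeros((n_out, batch, n_in))
    attention[:, :, 0] = 1.0                                  # all reference mass on input 0
    pred_attention = np.full((n_out - 1, batch, n_in), 0.5)   # uniform prediction
    inputs_mask = np.ones((n_in, batch))
    outputs_mask = np.ones((n_out, batch))
    # 2 target steps x 2 sentences, each contributing -ln(0.5 + eps), divided by 2
    print(attention_xent(attention, pred_attention, inputs_mask, outputs_mask))
    # about 1.3863 (2 * ln 2, minus a tiny epsilon effect)
```

Because the mask is applied both inside and outside the log, a padded input position contributes nothing even when the reference attention there is nonzero and the predicted attention is exactly zero.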