def xent(self, inputs, inputs_mask, chars, chars_mask,
         outputs, outputs_mask, attention):
    """Compute the two cross-entropy loss terms for one batch.

    Runs the model forward (via ``self(...)``) and returns a pair:
    the sequence cross-entropy of the predicted output symbols against
    the gold outputs, and a cross-entropy term between the supervised
    attention matrix and the model's predicted attention, normalized by
    batch size.  Targets are shifted by one position (``[1:]``) relative
    to the decoder inputs.
    """
    predicted, soft_alignment = self(
        inputs, inputs_mask, chars, chars_mask, outputs, outputs_mask)
    token_xent = batch_sequence_crossentropy(
        predicted, outputs[1:], outputs_mask[1:])
    # Symbolic batch size, cast so the final division stays in floatX.
    n_sentences = attention.shape[1].astype(theano.config.floatX)
    # Joint validity mask: a (target position, sentence, source position)
    # cell counts only where both the source token and the (shifted)
    # target token are unmasked.
    valid = (inputs_mask.dimshuffle('x', 1, 0) *
             outputs_mask[1:].dimshuffle(0, 1, 'x')
             ).astype(theano.config.floatX)
    # soft_alignment is zero at masked-out character positions; adding
    # (1 - valid) there turns log()'s argument into ~1 so it stays
    # finite, and the subsequent multiplication by `valid` zeroes the
    # contribution anyway.  epsilon guards against log(0) elsewhere.
    epsilon = 1e-6
    alignment_xent = (
        -attention[1:]
        * T.log(epsilon + soft_alignment + (1 - valid))
        * valid).sum() / n_sentences
    return token_xent, alignment_xent