def get_output_for(self, input, **kwargs):
    def norm_fn(f, mask, label, previous, W_sim):
        # f: inst * class (emission scores at the current time step)
        # mask: inst, previous: inst * class, W_sim: class * class
        next = previous.dimshuffle(0, 1, 'x') + f.dimshuffle(0, 'x', 1) + W_sim.dimshuffle('x', 0, 1)
        if COST:
            # cost-augmented training: add a margin to every class except the gold label
            next = next + COST_CONST * (1.0 - T.extra_ops.to_one_hot(label, self.num_classes).dimshuffle(0, 'x', 1))
        # next: inst * prev * cur
        next = theano_logsumexp(next, axis=1)
        # next: inst * class
        mask = mask.dimshuffle(0, 'x')
        # padded positions carry the previous forward scores through unchanged
        next = previous * (1.0 - mask) + next * mask
        return next
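    # norm_fn implements one step of the CRF forward algorithm:
    #   alpha_t(y) = logsumexp_{y'} [ alpha_{t-1}(y') + f_t(y) + W_sim(y', y) ]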
    f = T.dot(input, self.W)
    # f: inst * time * class (emission scores for every token)
    initial = f[:, 0, :]
    if CRF_INIT:
        # add learned scores for starting the sequence in each class
        initial = initial + self.W_init[0].dimshuffle('x', 0)
    if COST:
        initial = initial + COST_CONST * (1.0 - T.extra_ops.to_one_hot(self.label_input[:, 0], self.num_classes))
    # run the forward recursion over time steps 1..T-1 (scan iterates over the time axis)
    outputs, _ = theano.scan(fn=norm_fn,
                             sequences=[f.dimshuffle(1, 0, 2)[1:],
                                        self.mask_input.dimshuffle(1, 0)[1:],
                                        self.label_input.dimshuffle(1, 0)[1:]],
                             outputs_info=initial, non_sequences=[self.W_sim], strict=True)
    # log partition function log Z, summed over the batch
    norm = T.sum(theano_logsumexp(outputs[-1], axis=1))
    # unary potential of the gold label sequence (masked over padded positions)
    f_pot = (f.reshape((-1, f.shape[-1]))[T.arange(f.shape[0] * f.shape[1]), self.label_input.flatten()]
             * self.mask_input.flatten()).sum()
    if CRF_INIT:
        f_pot += self.W_init[0][self.label_input[:, 0]].sum()
    labels = self.label_input
    # labels: inst * time
    shift_labels = T.roll(labels, -1, axis=1)
    mask = self.mask_input
    # mask: inst * time
    shift_mask = T.roll(mask, -1, axis=1)
    # zero the wrapped-around last column so a sequence that fills the whole time
    # dimension does not pick up a spurious transition from its last label back to its first
    shift_mask = T.set_subtensor(shift_mask[:, -1], 0.0)
    # pairwise (transition) potential of the gold label sequence
    g_pot = (self.W_sim[labels.flatten(), shift_labels.flatten()] * mask.flatten() * shift_mask.flatten()).sum()
    # negative log-likelihood, averaged over the batch
    return - (f_pot + g_pot - norm) / f.shape[0]
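
The snippet calls a theano_logsumexp helper that is not defined here; the flags CRF_INIT, COST and the constant COST_CONST are likewise assumed to be module-level configuration. A minimal sketch of such a helper, assuming the standard max-shifted formulation for numerical stability (the original project may define it slightly differently):

import theano.tensor as T

def theano_logsumexp(x, axis=None):
    # numerically stable log(sum(exp(x))) along the given axis;
    # assumed implementation, the original helper may differ
    x_max = T.max(x, axis=axis, keepdims=True)
    return T.log(T.sum(T.exp(x - x_max), axis=axis)) + T.max(x, axis=axis)

With such a helper in place, get_output_for returns the average negative log-likelihood -(f_pot + g_pot - log Z) / batch_size, which can be used directly as the training loss.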