def get_output_for(self, input, **kwargs):
    """Build the symbolic best-path (Viterbi-style) label decoding.

    A forward scan keeps, for every instance and label, the best
    accumulated score together with a back-pointer to the previous label
    that achieved it.  A second scan, run backwards over the
    back-pointers, recovers the label chosen at each time step.

    Args:
        input: symbolic tensor multiplied by ``self.W``; the product is
            indexed as inst * time * class.
            # assumes input is inst * time * feature -- TODO confirm
        **kwargs: accepted for interface compatibility, not used here.

    Returns:
        Symbolic int32 matrix, inst * time, holding the decoded label
        index for every time step.

    Notes:
        Reads ``self.W``, ``self.W_sim``, ``self.mask_input`` and, when the
        module-level flag ``CRF_INIT`` is truthy, ``self.W_init[0]``.
        The mask is blended arithmetically, so it is presumably 0/1
        valued -- verify against the caller.
    """

    def forward_step(emit_t, mask_t, score_prev, back_prev, transition):
        # candidate[i, p, c]: accumulated score of instance i when moving
        # from previous label p to current label c at this time step.
        candidate = (score_prev.dimshuffle(0, 1, 'x')
                     + emit_t.dimshuffle(0, 'x', 1)
                     + transition.dimshuffle('x', 0, 1))
        best_prev = T.argmax(candidate, axis=1)
        best_score = T.max(candidate, axis=1)
        keep = mask_t.dimshuffle(0, 'x')
        # Where the mask is 0 the previous score / back-pointer is carried
        # forward unchanged.
        score_t = best_score * keep + score_prev * (1.0 - keep)
        back_t = best_prev * keep + back_prev * (1.0 - keep)
        return [score_t, T.cast(back_t, 'int32')]

    def backtrack_step(back_t, mask_t, label_next):
        # back_t: inst * class, label_next: inst, mask_t: inst
        rows = T.arange(label_next.shape[0])
        followed = back_t[rows, label_next]
        # Masked steps simply repeat the label already chosen.
        blended = mask_t * followed + (1.0 - mask_t) * label_next
        return T.cast(blended, 'int32')

    emissions = T.dot(input, self.W)
    start_score = emissions[:, 0, :]
    start_back = T.zeros_like(emissions[:, 0, :], dtype='int32')
    if CRF_INIT:
        start_score = start_score + self.W_init[0].dimshuffle('x', 0)

    # Time-major views with the first step removed; step 0 seeds the scan.
    later_emissions = emissions.dimshuffle(1, 0, 2)[1:]
    later_mask = self.mask_input.dimshuffle(1, 0)[1:]

    (scores, backs), _ = theano.scan(
        fn=forward_step,
        sequences=[later_emissions, later_mask],
        outputs_info=[start_score, start_back],
        non_sequences=[self.W_sim],
        strict=True,
    )

    # final_label: inst, backs: time * inst * class
    final_label = T.cast(T.argmax(scores[-1], axis=1), 'int32')

    reversed_labels, _ = theano.scan(
        fn=backtrack_step,
        sequences=[backs, self.mask_input.dimshuffle(1, 0)[1:]],
        outputs_info=[final_label],
        go_backwards=True,
    )
    # reversed_labels: (rev_time - 1) * inst  ->  inst * (time - 1)
    earlier_labels = reversed_labels.dimshuffle(1, 0)[:, ::-1]
    return T.concatenate(
        [earlier_labels, final_label.dimshuffle(0, 'x')], axis=1)
# (Web-page navigation residue, not part of the program: "Comment list" / "Table of contents".)