def _backward(gamma, mask):
    """Backward recurrence of the linear chain crf.

    Walks the table ``gamma`` from the last time step to the first. The
    walk starts from tag index 0 for every batch row. At each step the
    tag chosen at the previous step is used to select the next tag from
    the current time slice. The collected tags are then put back into
    forward time order.

    Args:
        gamma: Table of tag indices; cast to int32 here. It is indexed as
            ``gamma[:, 0, 0]``, so it has rank >= 3. Presumably it is
            (batch, time, num_tags) back-pointers from a Viterbi forward
            pass (TODO confirm against the caller).
        mask: ``None``, or a per-step mask that is cast to int32.

    Returns:
        Integer tensor of decoded tags in forward time order. If ``mask``
        is given, each entry becomes ``tag * mask - (1 - mask)``. For a
        0/1 mask, that keeps the tag where mask is 1 and gives -1 where
        mask is 0.
    """
    backpointers = K.cast(gamma, 'int32')

    def _step(pointers_t, carry):
        # The carried tag vector has an extra leading axis of size 1;
        # remove it before using it as indices.
        prev_tag = K.squeeze(carry[0], 0)
        # NOTE(review): presumably selects pointers_t[b, prev_tag[b]] for
        # each batch row b — confirm against KC.batch_gather.
        cur_tag = KC.batch_gather(pointers_t, prev_tag)
        return cur_tag, [K.expand_dims(cur_tag, 0)]

    # Start state: one int32 zero per batch row (same dtype as
    # backpointers), carrying the same leading size-1 axis.
    start = K.expand_dims(K.zeros_like(backpointers[:, 0, 0]), 0)
    _, reversed_tags, _ = K.rnn(_step,
                                backpointers,
                                [start],
                                go_backwards=True)
    # go_backwards=True emits outputs last-step-first; flip axis 1 back.
    tags = K.reverse(reversed_tags, 1)

    if mask is None:
        return tags

    int_mask = K.cast(mask, dtype='int32')
    # Zero out masked steps, then shift them down to -1.
    tags = tags * int_mask
    tags = tags + -(1 - int_mask)
    return tags
评论列表
文章目录