def _forward(x, reduce_step, initial_states, U, mask=None):
'''Forward recurrence of the linear chain crf.'''
def _forward_step(energy_matrix_t, states):
alpha_tm1 = states[-1]
new_states = reduce_step(K.expand_dims(alpha_tm1, 2) + energy_matrix_t)
return new_states[0], new_states
U_shared = K.expand_dims(K.expand_dims(U, 0), 0)
if mask is not None:
mask = K.cast(mask, K.floatx())
mask_U = K.expand_dims(K.expand_dims(mask[:, :-1] * mask[:, 1:], 2), 3)
U_shared = U_shared * mask_U
inputs = K.expand_dims(x[:, 1:, :], 2) + U_shared
inputs = K.concatenate([inputs, K.zeros_like(inputs[:, -1:, :, :])], axis=1)
last, values, _ = K.rnn(_forward_step, inputs, initial_states)
return last, values
评论列表
文章目录