from keras import backend as K
from keras.layers import (LSTM, Bidirectional, Concatenate, Dense,
                          Dropout, Embedding, Input, Lambda)
from keras.models import Model
# ChainCRF is the project's custom linear-chain CRF layer,
# defined elsewhere in the source this snippet comes from.


def __init__(self, config, embeddings=None, ntags=None):
    # Word-level embedding: trained from scratch, or initialized from a
    # pre-trained matrix. mask_zero=True treats id 0 as padding.
    word_ids = Input(batch_shape=(None, None), dtype='int32')
    if embeddings is None:
        word_embeddings = Embedding(input_dim=config.vocab_size,
                                    output_dim=config.word_embedding_size,
                                    mask_zero=True)(word_ids)
    else:
        word_embeddings = Embedding(input_dim=embeddings.shape[0],
                                    output_dim=embeddings.shape[1],
                                    mask_zero=True,
                                    weights=[embeddings])(word_ids)

    # Character-based word embedding; char_ids has shape
    # (batch_size, max_sentence_len, max_word_len).
    char_ids = Input(batch_shape=(None, None, None), dtype='int32')
    char_embeddings = Embedding(input_dim=config.char_vocab_size,
                                output_dim=config.char_embedding_size,
                                mask_zero=True)(char_ids)

    # Fold the sentence dimension into the batch so each word's character
    # sequence runs through the character LSTMs independently.
    s = K.shape(char_embeddings)
    char_embeddings = Lambda(
        lambda x: K.reshape(x, shape=(-1, s[-2], config.char_embedding_size)))(char_embeddings)

    # Final hidden state (index -2 of [output, state_h, state_c]) of a
    # forward and a backward LSTM over each word's characters.
    fwd_state = LSTM(config.num_char_lstm_units, return_state=True)(char_embeddings)[-2]
    bwd_state = LSTM(config.num_char_lstm_units, return_state=True,
                     go_backwards=True)(char_embeddings)[-2]
    char_embeddings = Concatenate(axis=-1)([fwd_state, bwd_state])

    # Restore shape (batch_size, max_sentence_len, 2 * num_char_lstm_units).
    char_embeddings = Lambda(
        lambda x: K.reshape(x, shape=[-1, s[1], 2 * config.num_char_lstm_units]))(char_embeddings)

    # Combine word-level and character-level representations.
    x = Concatenate(axis=-1)([word_embeddings, char_embeddings])
    x = Dropout(config.dropout)(x)

    # Sentence-level BiLSTM, then a tanh projection and per-tag scores
    # that feed the CRF.
    x = Bidirectional(LSTM(units=config.num_word_lstm_units,
                           return_sequences=True))(x)
    x = Dropout(config.dropout)(x)
    x = Dense(config.num_word_lstm_units, activation='tanh')(x)
    x = Dense(ntags)(x)

    # Linear-chain CRF decodes the best tag sequence from the scores.
    self.crf = ChainCRF()
    pred = self.crf(x)

    sequence_lengths = Input(batch_shape=(None, 1), dtype='int32')
    self.model = Model(inputs=[word_ids, char_ids, sequence_lengths],
                       outputs=[pred])
    self.config = config
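
For context, here is a minimal sketch of how this constructor might be driven. The Config class, its field values, and the SeqLabeling class name are hypothetical stand-ins (the snippet above does not show the enclosing class or its configuration object), and the embedding matrix is random for illustration only.

import numpy as np

# Hypothetical hyper-parameter container; the attribute names mirror
# exactly what the constructor reads from `config`. Values are examples.
class Config:
    vocab_size = 10000
    word_embedding_size = 100
    char_vocab_size = 85
    char_embedding_size = 25
    num_char_lstm_units = 25
    num_word_lstm_units = 100
    dropout = 0.5

config = Config()

# Random stand-in for a pre-trained word-embedding matrix
# of shape (vocab_size, word_embedding_size).
embeddings = np.random.uniform(-0.05, 0.05,
                               (config.vocab_size, config.word_embedding_size))

# `SeqLabeling` is a placeholder name for whatever class the
# __init__ above belongs to.
tagger = SeqLabeling(config, embeddings=embeddings, ntags=10)
tagger.model.summary()

Note that the underlying Keras model takes three inputs (word ids, character ids, and sequence lengths), so training and prediction must feed all three.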