def fit_model(self, X, Y, use_attention, att_context, bidirectional):
    """Build, compile and train a sequence tagger on (X, Y).

    Two architectures are supported:

    * ``bidirectional=True``: a Keras ``Graph`` with a forward and a backward
      LSTM (each ``X.shape[-1] // 2`` units) whose sequence outputs are
      concatenated and fed to a time-distributed softmax layer.
    * ``bidirectional=False``: a ``Sequential`` model made of a 50-dim
      time-distributed projection, an LSTM and a time-distributed softmax.

    When ``use_attention`` is true a ``TensorAttention`` layer (built with
    ``att_context``) is inserted before the LSTM(s).  In the sequential model
    X is then unpacked as 4-D ``(_, input_len, timesteps, input_dim)``;
    otherwise it is unpacked as 3-D ``(_, input_len, input_dim)``.

    Args:
        X: input array; its shape determines the layer sizes (see above).
        Y: targets for the softmax output; presumably one-hot labels over
            ``len(self.label_ind)`` classes per timestep -- TODO confirm.
        use_attention: insert a ``TensorAttention`` layer before the LSTM(s).
        att_context: ``context`` argument passed to ``TensorAttention``.
        bidirectional: pick the Graph (bi-LSTM) or Sequential architecture.

    Returns:
        The trained Keras model (a ``Graph`` or a ``Sequential``).
    """
    # sys.stderr.write works on Python 2 and 3 alike; the former
    # ``print >>sys.stderr, ...`` statement is Python 2-only syntax.
    sys.stderr.write("Input shape: %s %s\n" % (X.shape, Y.shape))
    # Stop after 2 epochs without improvement of the monitored quantity.
    early_stopping = EarlyStopping(patience=2)
    num_classes = len(self.label_ind)
    if bidirectional:
        tagger = Graph()
        tagger.add_input(name='input', input_shape=X.shape[1:])
        if use_attention:
            tagger.add_node(TensorAttention(X.shape[1:], context=att_context),
                            name='attention', input='input')
            lstm_input_node = 'attention'
        else:
            lstm_input_node = 'input'
        # Each direction gets half of X's last dimension.  Floor division
        # keeps the unit count an int under Python 3 (``/`` would give a
        # float there); on Python 2 the value is unchanged.
        half_dim = X.shape[-1] // 2
        tagger.add_node(LSTM(half_dim, return_sequences=True),
                        name='forward', input=lstm_input_node)
        # NOTE(review): with return_sequences=True, a go_backwards LSTM may
        # emit its outputs in reversed time order (later Keras versions'
        # Bidirectional wrapper re-reverses them).  They are concatenated
        # here as-is, so verify per-timestep alignment for tagging.
        tagger.add_node(LSTM(half_dim, return_sequences=True, go_backwards=True),
                        name='backward', input=lstm_input_node)
        tagger.add_node(TimeDistributedDense(num_classes, activation='softmax'),
                        name='softmax', inputs=['forward', 'backward'],
                        merge_mode='concat', concat_axis=-1)
        tagger.add_output(name='output', input='softmax')
        tagger.summary()
        tagger.compile('adam', {'output': 'categorical_crossentropy'})
        tagger.fit({'input': X, 'output': Y}, validation_split=0.1,
                   callbacks=[early_stopping], show_accuracy=True,
                   nb_epoch=100, batch_size=10)
    else:
        tagger = Sequential()
        # Width of the per-timestep projection that feeds the LSTM.
        word_proj_dim = 50
        if use_attention:
            _, input_len, timesteps, input_dim = X.shape
            tagger.add(HigherOrderTimeDistributedDense(input_dim=input_dim,
                                                       output_dim=word_proj_dim))
            att_input_shape = (input_len, timesteps, word_proj_dim)
            sys.stderr.write("Attention input shape: %s\n" % (att_input_shape,))
            tagger.add(Dropout(0.5))
            tagger.add(TensorAttention(att_input_shape, context=att_context))
        else:
            _, input_len, input_dim = X.shape
            tagger.add(TimeDistributedDense(input_dim=input_dim,
                                            input_length=input_len,
                                            output_dim=word_proj_dim))
        tagger.add(LSTM(input_dim=word_proj_dim, output_dim=word_proj_dim,
                        input_length=input_len, return_sequences=True))
        tagger.add(TimeDistributedDense(num_classes, activation='softmax'))
        tagger.summary()
        tagger.compile(loss='categorical_crossentropy', optimizer='adam')
        # NOTE(review): unlike the Graph branch (nb_epoch=100), no epoch
        # count is passed here, so Keras's default applies and early
        # stopping may rarely get a chance to fire -- confirm intended.
        tagger.fit(X, Y, validation_split=0.1, callbacks=[early_stopping],
                   show_accuracy=True, batch_size=10)
    return tagger
# (web-page residue) Comment list
# (web-page residue) Article table of contents