nn_passage_tagger.py source code


Project: sciDT    Author: edvisees
import sys

# This method is written against the old Keras 0.x API (Graph models,
# TimeDistributedDense, nb_epoch/show_accuracy) and Python 2 print syntax.
from keras.models import Graph, Sequential
from keras.layers.core import Dropout, TimeDistributedDense
from keras.layers.recurrent import LSTM
from keras.callbacks import EarlyStopping

# TensorAttention and HigherOrderTimeDistributedDense are custom layers defined
# elsewhere in the sciDT project; self.label_ind is the label-to-index mapping
# built by the surrounding tagger class.

def fit_model(self, X, Y, use_attention, att_context, bidirectional):
    print >>sys.stderr, "Input shape:", X.shape, Y.shape
    early_stopping = EarlyStopping(patience=2)
    num_classes = len(self.label_ind)
    if bidirectional:
      # Bidirectional variant: a Graph model with forward and backward LSTMs
      # whose outputs are concatenated before the per-timestep softmax.
      tagger = Graph()
      tagger.add_input(name='input', input_shape=X.shape[1:])
      if use_attention:
        tagger.add_node(TensorAttention(X.shape[1:], context=att_context), name='attention', input='input')
        lstm_input_node = 'attention'
      else:
        lstm_input_node = 'input'
      tagger.add_node(LSTM(X.shape[-1]/2, return_sequences=True), name='forward', input=lstm_input_node)
      tagger.add_node(LSTM(X.shape[-1]/2, return_sequences=True, go_backwards=True), name='backward', input=lstm_input_node)
      tagger.add_node(TimeDistributedDense(num_classes, activation='softmax'), name='softmax', inputs=['forward', 'backward'], merge_mode='concat', concat_axis=-1)
      tagger.add_output(name='output', input='softmax')
      tagger.summary()
      tagger.compile('adam', {'output': 'categorical_crossentropy'})
      tagger.fit({'input': X, 'output': Y}, validation_split=0.1, callbacks=[early_stopping], show_accuracy=True, nb_epoch=100, batch_size=10)
    else:
      # Unidirectional variant: a Sequential model that projects the input to
      # word_proj_dim, optionally applies tensor attention, then runs an LSTM
      # followed by a per-timestep softmax.
      tagger = Sequential()
      word_proj_dim = 50
      if use_attention:
        # 4-D input: (samples, input_len, timesteps, input_dim)
        _, input_len, timesteps, input_dim = X.shape
        tagger.add(HigherOrderTimeDistributedDense(input_dim=input_dim, output_dim=word_proj_dim))
        att_input_shape = (input_len, timesteps, word_proj_dim)
        print >>sys.stderr, "Attention input shape:", att_input_shape
        tagger.add(Dropout(0.5))
        tagger.add(TensorAttention(att_input_shape, context=att_context))
      else:
        # 3-D input: (samples, input_len, input_dim)
        _, input_len, input_dim = X.shape
        tagger.add(TimeDistributedDense(input_dim=input_dim, input_length=input_len, output_dim=word_proj_dim))
      tagger.add(LSTM(input_dim=word_proj_dim, output_dim=word_proj_dim, input_length=input_len, return_sequences=True))
      tagger.add(TimeDistributedDense(num_classes, activation='softmax'))
      tagger.summary()
      tagger.compile(loss='categorical_crossentropy', optimizer='adam')
      tagger.fit(X, Y, validation_split=0.1, callbacks=[early_stopping], show_accuracy=True, batch_size=10)

    return tagger
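
For context, here is a minimal usage sketch that is not part of the original file. The array shapes, the number of classes, and the nn_tagger instance name are illustrative assumptions; in the project, X and Y would come from the tagger's own preprocessing before fit_model is called.

# Illustrative call only; the instance, shapes, and label count below are
# assumptions, not values taken from the sciDT code.
import numpy as np

num_passages, max_clauses, clause_dim = 20, 40, 300
num_classes = 7  # stands in for len(nn_tagger.label_ind)

# 3-D input for the non-attention path: (samples, clauses, clause representation dim)
X = np.random.rand(num_passages, max_clauses, clause_dim)
# One-hot labels per clause: (samples, clauses, num_classes)
Y = np.eye(num_classes)[np.random.randint(num_classes, size=(num_passages, max_clauses))]

# nn_tagger is assumed to be an instance of the tagger class defined in
# nn_passage_tagger.py; att_context is unused here because use_attention is False.
model = nn_tagger.fit_model(X, Y, use_attention=False, att_context=None, bidirectional=False)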