classify_models.py source code

python

Project: nli_generation  Author: jstarc
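# Written against the Keras 1.x API (output_dim, inner_activation,
# Model(input=..., output=...)).
from keras.layers import Input, LSTM, BatchNormalization, Dense
from keras.models import Model

# make_fixed_embeddings and LstmAttentionLayer are not part of Keras; they are
# defined elsewhere in the nli_generation project.


def attention_bnorm_model(hidden_size, glove):
    """NLI classifier: batch-normalized LSTM encoders feeding an attention layer."""
    # Variable-length sequences of integer token ids.
    prem_input = Input(shape=(None,), dtype='int32')
    hypo_input = Input(shape=(None,), dtype='int32')

    # Fixed (pre-trained GloVe) embeddings for premise and hypothesis.
    prem_embeddings = make_fixed_embeddings(glove, None)(prem_input)
    hypo_embeddings = make_fixed_embeddings(glove, None)(hypo_input)

    # Encode each sentence, keeping the hidden state at every time step.
    premise_layer = LSTM(output_dim=hidden_size, return_sequences=True,
                         inner_activation='sigmoid')(prem_embeddings)
    premise_bn = BatchNormalization()(premise_layer)
    hypo_layer = LSTM(output_dim=hidden_size, return_sequences=True,
                      inner_activation='sigmoid')(hypo_embeddings)
    hypo_bn = BatchNormalization()(hypo_layer)

    # Project-defined attention layer over [hypothesis states, premise states].
    attention = LstmAttentionLayer(output_dim=hidden_size)([hypo_bn, premise_bn])
    att_bn = BatchNormalization()(attention)

    # Softmax over the three NLI classes.
    final_dense = Dense(3, activation='softmax')(att_bn)

    model = Model(input=[prem_input, hypo_input], output=final_dense)
    model.compile(loss='categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    return model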
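The model above takes two int32 token-id matrices (premise, hypothesis) and is trained with categorical_crossentropy against a 3-way softmax, so labels must be one-hot rows. The sketch below shows one way to prepare such a batch with NumPy. It assumes padding id 0 and the label order entailment/neutral/contradiction; both are illustrative choices, not taken from the nli_generation project, and the helper names `pad_batch` and `one_hot_labels` are hypothetical.

```python
import numpy as np

# Hypothetical label order; the original code only fixes the number of classes (3).
LABELS = ("entailment", "neutral", "contradiction")


def pad_batch(sequences, pad_id=0):
    """Right-pad lists of token ids into an int32 matrix of shape (batch, max_len).

    pad_id=0 is an assumption; use whatever id the embedding matrix reserves.
    """
    max_len = max(len(s) for s in sequences)
    batch = np.full((len(sequences), max_len), pad_id, dtype="int32")
    for i, seq in enumerate(sequences):
        batch[i, :len(seq)] = seq
    return batch


def one_hot_labels(labels):
    """Map label strings to one-hot float rows, as categorical_crossentropy expects."""
    index = {name: i for i, name in enumerate(LABELS)}
    out = np.zeros((len(labels), len(LABELS)), dtype="float32")
    for row, name in enumerate(labels):
        out[row, index[name]] = 1.0
    return out


# Toy token ids; premise and hypothesis batches may have different lengths
# because both Input layers use shape=(None,).
premises = [[4, 17, 9], [5, 2]]
hypotheses = [[4, 8], [7, 3, 3, 1]]
X_prem = pad_batch(premises)
X_hypo = pad_batch(hypotheses)
y = one_hot_labels(["entailment", "contradiction"])
print(X_prem.shape, X_hypo.shape, y.shape)  # (2, 3) (2, 4) (2, 3)
```

The inputs are passed in the same order as `input=[prem_input, hypo_input]`, e.g. `model.fit([X_prem, X_hypo], y)`.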