generative_models.py source code

python

Project: nli_generation  Author: jstarc
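# Assumes the Keras 1.x functional API (merge, output_dim, Model(input=..., output=...)).
from keras.layers import Input, LSTM, Embedding, Flatten, Dense, Lambda, merge
from keras.models import Model
# make_fixed_embeddings, FeedLSTM, HierarchicalSoftmax and hs_categorical_crossentropy
# are project-specific helpers from nli_generation; their definitions are not shown here.


def baseline_train(noise_examples, hidden_size, noise_dim, glove, hypo_len, version):
    """Build and compile the baseline hypothesis-generation model for training."""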
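    # Premise token ids (variable length) and hypothesis token ids (hypo_len + 1 steps).
    prem_input = Input(shape=(None,), dtype='int32', name='prem_input')
    hypo_input = Input(shape=(hypo_len + 1,), dtype='int32', name='hypo_input')
    # A single integer per example, used as an index into the noise embedding table.
    noise_input = Input(shape=(1,), dtype='int32', name='noise_input')
    # Passed to the hierarchical softmax together with the decoder output.
    train_input = Input(shape=(None,), dtype='int32', name='train_input')
    # 3-dimensional class vector (presumably a one-hot NLI label).
    class_input = Input(shape=(3,), name='class_input')
    concat_dim = hidden_size + noise_dim + 3
    prem_embeddings = make_fixed_embeddings(glove, None)(prem_input)
    hypo_embeddings = make_fixed_embeddings(glove, hypo_len + 1)(hypo_input)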

    premise_layer = LSTM(output_dim=hidden_size, return_sequences=False,
                         inner_activation='sigmoid', name='premise')(prem_embeddings)

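    # Learned noise vector: one row of size noise_dim per noise index.
    noise_layer = Embedding(noise_examples, noise_dim,
                            input_length=1, name='noise_embeddings')(noise_input)
    flat_noise = Flatten(name='noise_flatten')(noise_layer)
    # Concatenate premise encoding, class vector and noise vector (concat_dim values).
    merged = merge([premise_layer, class_input, flat_noise], mode='concat')
    creative = Dense(concat_dim, name='cmerge')(merged)
    # Identity on hypo_embeddings; listing `creative` as a second input only
    # ties it into the graph so it is available to the decoder below.
    fake_merge = Lambda(lambda x: x[0], output_shape=lambda x: x[0])([hypo_embeddings, creative])
    # FeedLSTM is a project-specific layer that receives `creative` via feed_layer.
    hypo_layer = FeedLSTM(output_dim=concat_dim, return_sequences=True,
                          feed_layer=creative, inner_activation='sigmoid',
                          name='attention')([fake_merge])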

    # Hierarchical softmax over the vocabulary (len(glove) entries); it takes the
    # decoder states together with train_input.
    hs = HierarchicalSoftmax(len(glove), trainable=True, name='hs')([hypo_layer, train_input])
    inputs = [prem_input, hypo_input, noise_input, train_input, class_input]

    model_name = 'version' + str(version)
    model = Model(input=inputs, output=hs, name=model_name)
    model.compile(loss=hs_categorical_crossentropy, optimizer='adam')

    return model
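
The function above sizes its conditioning vector as `concat_dim = hidden_size + noise_dim + 3`. The following is a minimal, framework-free sketch of that concatenation step only, assuming plain Python lists stand in for tensors. The helper name `concat_condition` and all toy values are hypothetical and not part of the original project; the `Dense` layer `cmerge` that follows the merge in the real model is omitted.

```python
def concat_condition(premise_vec, class_vec, noise_table, noise_index):
    """Join premise encoding, class vector and a looked-up noise row."""
    # Embedding lookup: pick the row of the noise table at the given index.
    noise_vec = noise_table[noise_index]
    # Same order as merge([premise_layer, class_input, flat_noise]).
    return premise_vec + class_vec + noise_vec


# Toy sizes (hypothetical): hidden_size=4, noise_dim=2, noise_examples=3.
hidden_size, noise_dim = 4, 2
premise_vec = [0.1, 0.2, 0.3, 0.4]
class_vec = [0.0, 1.0, 0.0]
noise_table = [[0.5, 0.5], [-1.0, 1.0], [0.0, 2.0]]

merged = concat_condition(premise_vec, class_vec, noise_table, noise_index=1)
print(len(merged))
```

The printed length equals `hidden_size + noise_dim + 3`, which is the value `baseline_train` passes as `concat_dim` to both `Dense` and `FeedLSTM`.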