models.py source code

python

Project: NeuralNetwork-ImageQA    Author: ayushoriginal

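# Requires Keras 1.x: the legacy `Merge` layer was removed in Keras 2.2.
from keras.models import Sequential
from keras.layers import Embedding, Dense, Reshape, Merge, LSTM, Dropout

import embedding  # project-local module (assumed); embedding.load() returns the word-vector matrix


def vis_lstm_2():
    # Pre-trained word vectors, shape (vocab_size, embed_dim).
    embedding_matrix = embedding.load()

    # Question branch: frozen word embeddings -> (batch, T, embed_dim).
    embedding_model = Sequential()
    embedding_model.add(Embedding(
        embedding_matrix.shape[0],
        embedding_matrix.shape[1],
        weights=[embedding_matrix],
        trainable=False))

    # First image branch: project the 4096-d image feature into the word-embedding
    # space and reshape it into a single time step -> (batch, 1, embed_dim).
    image_model_1 = Sequential()
    image_model_1.add(Dense(
        embedding_matrix.shape[1],
        input_dim=4096,
        activation='linear'))
    image_model_1.add(Reshape((1, embedding_matrix.shape[1])))

    # Second image branch: same shape, independent weights; it becomes the last time step.
    image_model_2 = Sequential()
    image_model_2.add(Dense(
        embedding_matrix.shape[1],
        input_dim=4096,
        activation='linear'))
    image_model_2.add(Reshape((1, embedding_matrix.shape[1])))

    # Concatenate along the time axis: [image, word_1 ... word_T, image].
    main_model = Sequential()
    main_model.add(Merge(
        [image_model_1, embedding_model, image_model_2],
        mode='concat',
        concat_axis=1))
    # Encode the (T + 2)-step sequence; the LSTM returns its final hidden state.
    main_model.add(LSTM(1001))
    main_model.add(Dropout(0.5))
    # 1001-way softmax over the output (answer) classes.
    main_model.add(Dense(1001, activation='softmax'))

    return main_model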
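Because `Merge(mode='concat', concat_axis=1)` joins the three branches along the time axis, the LSTM sees the question framed by two projected image "words", for T + 2 steps in total. The NumPy sketch below reproduces only that sequence-assembly step, as an illustration rather than the project's code: random matrices stand in for the trained `Dense` weights, and the function name `build_lstm_input`, the batch size of 2, T = 5 and d = 300 are hypothetical values chosen for the example.

```python
import numpy as np


def build_lstm_input(image_features, question_embeddings, w1, b1, w2, b2):
    """Assemble the LSTM input the way vis_lstm_2 does (illustrative sketch).

    image_features:      (batch, 4096) image feature vectors
    question_embeddings: (batch, T, d) word embeddings of the question
    w1, b1 / w2, b2:     hypothetical weights of the two independent linear
                         image projections (image_model_1 / image_model_2)
    Returns an array of shape (batch, T + 2, d).
    """
    d = question_embeddings.shape[2]
    # Dense(d, activation='linear') then Reshape((1, d)): one "image word" each
    first = (image_features @ w1 + b1).reshape(-1, 1, d)
    last = (image_features @ w2 + b2).reshape(-1, 1, d)
    # Merge(mode='concat', concat_axis=1): join along the time axis
    return np.concatenate([first, question_embeddings, last], axis=1)


# Hypothetical stand-ins: random values replace real features and trained weights.
rng = np.random.default_rng(0)
batch, T, d = 2, 5, 300
image_features = rng.standard_normal((batch, 4096))
question_embeddings = rng.standard_normal((batch, T, d))
w1, w2 = rng.standard_normal((4096, d)), rng.standard_normal((4096, d))
b1, b2 = np.zeros(d), np.zeros(d)

seq = build_lstm_input(image_features, question_embeddings, w1, b1, w2, b2)
print(seq.shape)  # (2, 7, 300)
```

The two extra time steps in the printed shape are the image projections. Since `image_model_1` and `image_model_2` are separate layers with their own weights, the first and last image tokens are free to differ after training.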