model.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:CNN-Sentence-Classifier 作者: shagunsodhani 项目源码 文件源码
def _predefined_model(args, embedding_matrix):
    """Build and compile one of the predefined CNN models (based on the paper).

    Hyper-parameters are obtained from ``_param_selector(args)``. The
    resulting network is a ``Sequential`` stack: embedding, dropout, a
    sub-model of parallel Conv1D/MaxPooling1D/Flatten branches (one per
    filter size, concatenated when there is more than one), a 150-unit
    dense layer followed by dropout and ReLU, and a softmax output layer.

    Args:
        args: Settings object. This function reads ``nb_words``,
            ``embedding_dim``, ``max_sequence_len`` and
            ``len_labels_index`` from it and forwards it to
            ``_param_selector``.
        embedding_matrix: Initial weights for the embedding layer. Only
            used when the selected parameters enable embeddings.

    Returns:
        The compiled ``Sequential`` model.
    """
    (kernel_sizes, filter_counts, pool_sizes, dropout_rates, optimizer,
     use_embeddings, embeddings_trainable) = _param_selector(args)

    vocab_size = args.nb_words + 1
    emb_dim = args.embedding_dim
    seq_len = args.max_sequence_len

    # Seed the embedding with the supplied matrix only when requested;
    # otherwise the layer is created without initial weights.
    initial_weights = [embedding_matrix] if use_embeddings else None
    embedding_layer = Embedding(vocab_size,
                                emb_dim,
                                weights=initial_weights,
                                input_length=seq_len,
                                trainable=embeddings_trainable)

    print('Defining model.')

    # Functional sub-model: every filter size gets its own
    # convolution -> max-pooling -> flatten branch over the same input.
    input_node = Input(shape=(seq_len, emb_dim))
    branches = []
    for position, width in enumerate(kernel_sizes):
        filters_here = filter_counts[position]
        pool_here = pool_sizes[position]
        convolved = Conv1D(nb_filter=filters_here,
                           filter_length=width,
                           activation='relu')(input_node)
        pooled = MaxPooling1D(pool_length=pool_here)(convolved)
        branches.append(Flatten()(pooled))

    # A single branch is used as-is; several branches are concatenated.
    features = (Merge(mode='concat')(branches)
                if len(branches) > 1 else branches[0])
    graph = Model(input=input_node, output=features)

    model = Sequential()
    stack = (
        embedding_layer,
        Dropout(dropout_rates[0], input_shape=(seq_len, emb_dim)),
        graph,
        Dense(150),
        # Dropout is applied ahead of the ReLU activation here.
        Dropout(dropout_rates[1]),
        Activation('relu'),
        Dense(args.len_labels_index, activation='softmax'),
    )
    for layer in stack:
        model.add(layer)

    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['acc'])
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号