cnn.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:sota_sentiment 作者: jbarnesspain 项目源码 文件源码
def create_cnn(W, max_length, dim=300,
               dropout=.5, output_dim=8):
    """Build and compile a multi-width 1-D convolutional classifier.

    The network is: embedding -> dropout -> parallel convolution branches
    (one per filter width, each followed by max-pooling and flattening,
    then concatenated) -> dense ReLU layer -> dropout -> softmax output.

    Args:
        W: 2-D embedding matrix exposing ``.shape``; ``W.shape[0]`` is used
            as the vocabulary size and ``W.shape[1]`` as the embedding
            width. It initialises the (trainable) embedding layer.
        max_length: Length of every input sequence.
        dim: Number of units in the hidden dense layer.
        dropout: Dropout rate applied after the embedding layer and after
            the hidden dense layer.
        output_dim: Number of softmax output units. A value of 2 selects
            ``binary_crossentropy`` as the loss; any other value selects
            ``categorical_crossentropy``.

    Returns:
        The compiled ``Sequential`` model (Adam optimizer, accuracy metric).
    """
    # Convolutional model
    filter_sizes = (2, 3, 4)
    num_filters = 3

    # Read the embedding width once so the convolutional sub-graph and the
    # embedding layer cannot disagree about it.
    vocab_size, embedding_dim = W.shape[0], W.shape[1]

    graph_in = Input(shape=(max_length, embedding_dim))
    convs = []
    for fsz in filter_sizes:
        conv = Convolution1D(nb_filter=num_filters,
                             filter_length=fsz,
                             border_mode='valid',
                             activation='relu',
                             subsample_length=1)(graph_in)
        pool = MaxPooling1D(pool_length=2)(conv)
        convs.append(Flatten()(pool))

    # Concatenate the flattened features of all filter widths.
    out = Merge(mode='concat')(convs)
    graph = Model(input=graph_in, output=out)

    # Full model
    model = Sequential()
    model.add(Embedding(output_dim=embedding_dim,
                        input_dim=vocab_size,
                        input_length=max_length, weights=[W],
                        trainable=True))
    model.add(Dropout(dropout))
    model.add(graph)
    model.add(Dense(dim, activation='relu'))
    model.add(Dropout(dropout))
    model.add(Dense(output_dim, activation='softmax'))

    # Only the loss depends on the number of classes.
    if output_dim == 2:
        loss = 'binary_crossentropy'
    else:
        loss = 'categorical_crossentropy'
    model.compile('adam', loss, metrics=['accuracy'])
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号