text_model.py source code

python

Project: text_classification  Author: senochow
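# Imports assumed from the Keras 1.x API that the code below is written against
# (nb_filter, border_mode, Merge and Model(input=..., output=...) are Keras 1.x names).
from keras.layers import Input, Convolution1D, MaxPooling1D, Flatten, Merge
from keras.models import Model


def TextCNN(sequence_length, embedding_dim, filter_sizes, num_filters):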
    ''' Convolutional Neural Network, including conv + pooling

    Args:
        sequence_length: length of the input sequence
        embedding_dim: dimension of the word embeddings
        filter_sizes: list of filter (kernel) sizes
        num_filters: number of filters per filter size

    Returns:
        features extracted by CNN
    '''
    graph_in = Input(shape=(sequence_length, embedding_dim))
    convs = []
    # Build one conv -> max-pool -> flatten branch per filter size.
    for fsz in filter_sizes:
        conv = Convolution1D(nb_filter=num_filters,
                             filter_length=fsz,
                             border_mode='valid',
                             activation='relu',
                             subsample_length=1)(graph_in)
        pool = MaxPooling1D()(conv)
        flatten = Flatten()(pool)
        convs.append(flatten)
    # Concatenate the branches; a single branch needs no merge.
    if len(filter_sizes) > 1:
        out = Merge(mode='concat')(convs)
    else:
        out = convs[0]
    graph = Model(input=graph_in, output=out)
    return graph
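
To make the output shape concrete, the width of the feature vector returned by TextCNN can be worked out by hand. The sketch below is a hypothetical helper (`textcnn_output_dim` is not part of the original project). It assumes the settings used above: a 'valid' convolution with stride 1, followed by MaxPooling1D with its default pool length of 2 and non-overlapping windows. Note that embedding_dim does not affect the result, because each filter spans the full embedding dimension.

```python
def textcnn_output_dim(sequence_length, filter_sizes, num_filters, pool_length=2):
    """Width of the concatenated feature vector (hypothetical helper, for illustration)."""
    total = 0
    for fsz in filter_sizes:
        # 'valid' convolution with stride 1 shortens the sequence by fsz - 1.
        conv_len = sequence_length - fsz + 1
        if conv_len < pool_length:
            raise ValueError("filter size %d is too large for the sequence" % fsz)
        # Non-overlapping max pooling keeps one value per full window.
        pooled_len = conv_len // pool_length
        # Flatten turns (pooled_len, num_filters) into one vector.
        total += pooled_len * num_filters
    return total


# Example: 56 tokens, filter sizes 3/4/5, 100 filters each.
# conv lengths 54, 53, 52 -> pooled lengths 27, 26, 26 -> (27 + 26 + 26) * 100
print(textcnn_output_dim(56, [3, 4, 5], 100))  # 7900
```

Any dense layer stacked on top of the returned model therefore sees an input of this width, so it can serve as a quick sanity check when changing filter_sizes or num_filters.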