def TextCNN(sequence_length, embedding_dim, filter_sizes, num_filters):
    """Build the convolution + pooling feature extractor of a TextCNN.

    One Convolution1D branch is created per entry of ``filter_sizes``. Each
    branch is followed by MaxPooling1D and Flatten. The flattened outputs of
    all branches are concatenated when there is more than one branch.

    Args:
        sequence_length: number of time steps (e.g. tokens) per input
            sequence; the first dimension of the model input shape.
        embedding_dim: length of the vector at each time step (e.g. the word
            embedding size); the second dimension of the model input shape.
        filter_sizes: iterable of convolution window lengths; one conv branch
            is built per entry. Must contain at least one size.
        num_filters: number of filters (output channels) in every branch.

    Returns:
        A Keras ``Model`` taking input of shape
        ``(sequence_length, embedding_dim)`` and returning the features
        extracted by the CNN.

    Raises:
        ValueError: if ``filter_sizes`` is empty.
    """
    # Materialize once so generators and other one-shot iterables work: the
    # sizes are iterated to build branches and the branch count is checked.
    filter_sizes = list(filter_sizes)
    if not filter_sizes:
        # Without this check the failure is an opaque IndexError on convs[0].
        raise ValueError("filter_sizes must contain at least one filter size")

    graph_in = Input(shape=(sequence_length, embedding_dim))
    convs = []
    for fsz in filter_sizes:
        # Keras 1 argument names: nb_filter / filter_length / border_mode /
        # subsample_length. 'valid' = no padding; subsample_length=1 = stride 1.
        conv = Convolution1D(nb_filter=num_filters,
                             filter_length=fsz,
                             border_mode='valid',
                             activation='relu',
                             subsample_length=1)(graph_in)
        # NOTE(review): MaxPooling1D() uses its default window (pool_length=2
        # in Keras 1), not max-over-time pooling across the whole sequence —
        # confirm this is intended.
        pool = MaxPooling1D()(conv)
        convs.append(Flatten()(pool))

    if len(convs) > 1:
        out = Merge(mode='concat')(convs)
    else:
        # Only one branch: nothing to concatenate.
        out = convs[0]
    return Model(input=graph_in, output=out)
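# 评论列表 ("Comment list") / 文章目录 ("Table of contents"): these appear to be
# web-page navigation labels captured along with the code; not part of the program.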