networks.py source code

python

Project: identifiera-sarkasm  Author: risnejunior
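# Imports inferred from the names used below (assumed, not shown in the excerpt).
import tensorflow as tf
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_1d, global_max_pool
from tflearn.layers.merge_ops import merge
from tflearn.layers.estimator import regression


def convolve_me(self, hyp, pd):
    # Method of a network-builder class; hyp is not used in this body.
    # Input: batches of token-id sequences padded to pd.max_sequence.
    network = input_data(shape=[None, pd.max_sequence], name='input')
    network = tflearn.embedding(network,
                                input_dim=pd.vocab_size,
                                output_dim=pd.emb_size,
                                name="embedding")
    # Three parallel 1-D convolutions with window sizes 3, 4 and 5.
    branch1 = conv_1d(network, 128, 3, padding='valid', activation='relu', regularizer="L2")
    branch2 = conv_1d(network, 128, 4, padding='valid', activation='relu', regularizer="L2")
    branch3 = conv_1d(network, 128, 5, padding='valid', activation='relu', regularizer="L2")
    # Concatenate the branches along the sequence axis, then max-pool over it.
    network = merge([branch1, branch2, branch3], mode='concat', axis=1)
    network = tf.expand_dims(network, 2)  # global_max_pool expects a 4-D tensor
    network = global_max_pool(network)
    # In TFLearn the second argument of dropout is the keep probability.
    network = dropout(network, 0.5)
    # Two-class softmax output trained with Adam and cross-entropy.
    network = fully_connected(network, 2, activation='softmax')
    network = regression(network, optimizer='adam', learning_rate=0.001,
                         loss='categorical_crossentropy', name='target')
    return network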
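
The tensor shapes in convolve_me are easy to lose track of, so here is a minimal sketch that traces them in plain Python without TFLearn. It assumes stride 1 (the conv_1d default) and 'valid' padding as in the code above. The helper names valid_conv_length and trace_shapes, and the example values max_sequence=100 and emb_size=50, are hypothetical and not part of the original project.

```python
def valid_conv_length(seq_len, kernel_size, stride=1):
    """Output length of a 1-D convolution with 'valid' (no) padding."""
    return (seq_len - kernel_size) // stride + 1


def trace_shapes(max_sequence, emb_size, n_filters=128, kernel_sizes=(3, 4, 5)):
    """Return the expected shape after each stage; None is the batch axis."""
    shapes = {}
    shapes["input"] = (None, max_sequence)
    shapes["embedding"] = (None, max_sequence, emb_size)

    # Each branch shortens the sequence by kernel_size - 1 steps.
    branch_lengths = [valid_conv_length(max_sequence, k) for k in kernel_sizes]
    for k, length in zip(kernel_sizes, branch_lengths):
        shapes["conv_k%d" % k] = (None, length, n_filters)

    # merge(..., mode='concat', axis=1) stacks the branches along the sequence axis.
    total = sum(branch_lengths)
    shapes["merge"] = (None, total, n_filters)

    # expand_dims(network, 2) inserts a width axis of size 1.
    shapes["expand_dims"] = (None, total, 1, n_filters)

    # Global max pooling keeps one value per filter.
    shapes["global_max_pool"] = (None, n_filters)
    shapes["fully_connected"] = (None, 2)
    return shapes


if __name__ == "__main__":
    for stage, shape in trace_shapes(max_sequence=100, emb_size=50).items():
        print(stage, shape)
```

Because the branches are concatenated along the sequence axis rather than the filter axis, the pooled vector has 128 features, not 384: the maximum is taken jointly over all three window sizes for each filter index.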