drmm_tks.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:MatchZoo 作者: faneshion 项目源码 文件源码
def build(self):
        """Build and return the DRMM_TKS Keras model.

        The model scores a (query, doc) pair of token-id sequences:

        1. A single ``Embedding`` layer (initial weights from
           ``config['embed']``) is shared by both inputs.
        2. ``Dot(axes=[2, 2], normalize=True)`` gives a cosine-similarity
           matching matrix of shape ``(text1_maxlen, text2_maxlen)``.
        3. For each query term the ``topk`` largest similarities are kept
           and passed through ``num_layers`` softplus ``Dense`` layers and
           dropout, yielding one score per query term.
        4. A term-gating vector (``Dense(1)`` on the query embeddings,
           softmax over query positions) weights the per-term scores; their
           weighted sum is the pair score.

        Reads from ``self.config``: ``text1_maxlen``, ``text2_maxlen``,
        ``vocab_size``, ``embed_size``, ``embed``, ``topk``, ``num_layers``,
        ``hidden_sizes``, ``dropout_rate`` and ``target_mode``; also reads
        ``self.embed_trainable``.

        Returns:
            A ``Model`` with inputs ``[query, doc]``. For
            ``target_mode == 'classification'`` the output is a 2-way
            softmax; for ``'regression'`` / ``'ranking'`` it is a single
            value of shape ``(1,)``.

        Raises:
            ValueError: if ``target_mode`` is unsupported, or if the last
                per-term layer width (``hidden_sizes[num_layers - 1]``, or
                ``topk`` when ``num_layers == 0``) is not 1, which the
                reshape to ``(text1_maxlen,)`` requires.
        """
        config = self.config

        # Validate configuration before building any layers so that a bad
        # config fails fast with a clear message.
        target_mode = config['target_mode']
        if target_mode not in ('classification', 'regression', 'ranking'):
            raise ValueError(
                "Unsupported target_mode %r; expected 'classification', "
                "'regression' or 'ranking'." % (target_mode,))
        num_layers = config['num_layers']
        topk = config['topk']
        last_dim = config['hidden_sizes'][num_layers - 1] if num_layers > 0 else topk
        if last_dim != 1:
            raise ValueError(
                "hidden_sizes[num_layers - 1] (or topk when num_layers == 0) "
                "must be 1 so per-term scores can be reshaped to "
                "(text1_maxlen,); got %r." % (last_dim,))

        text1_maxlen = config['text1_maxlen']

        query = Input(name='query', shape=(text1_maxlen,))
        show_layer_info('Input', query)
        doc = Input(name='doc', shape=(config['text2_maxlen'],))
        show_layer_info('Input', doc)

        # One embedding layer shared by query and doc.
        embedding = Embedding(config['vocab_size'], config['embed_size'],
                              weights=[config['embed']],
                              trainable=self.embed_trainable)
        q_embed = embedding(query)
        show_layer_info('Embedding', q_embed)
        d_embed = embedding(doc)
        show_layer_info('Embedding', d_embed)
        # normalize=True L2-normalises both operands before the dot product,
        # so each entry is a query-term/doc-term cosine similarity.
        mm = Dot(axes=[2, 2], normalize=True)([q_embed, d_embed])
        show_layer_info('Dot', mm)

        # Term gating: one logit per query term, softmax over query positions.
        w_g = Dense(1)(q_embed)
        show_layer_info('Dense', w_g)
        # Softmax along axis 1 keeps the trailing unit axis of Dense(1), so
        # the true per-sample output shape is (text1_maxlen, 1).
        g = Lambda(lambda x: softmax(x, axis=1),
                   output_shape=(text1_maxlen, 1))(w_g)
        show_layer_info('Lambda-softmax', g)
        g = Reshape((text1_maxlen,))(g)
        show_layer_info('Reshape', g)

        # Keep the top-k similarities per query term (sorted descending).
        # `topk` is a plain int here so the Lambda's closure does not
        # capture `self`.
        mm_k = Lambda(lambda x: K.tf.nn.top_k(x, k=topk, sorted=True)[0])(mm)
        show_layer_info('Lambda-topk', mm_k)

        for i in range(num_layers):
            mm_k = Dense(config['hidden_sizes'][i], activation='softplus',
                         kernel_initializer='he_uniform',
                         bias_initializer='zeros')(mm_k)
            show_layer_info('Dense', mm_k)

        mm_k_dropout = Dropout(rate=config['dropout_rate'])(mm_k)
        show_layer_info('Dropout', mm_k_dropout)

        # Drop the unit axis left by the last layer (width 1, checked above):
        # one score per query term.
        mm_reshape = Reshape((text1_maxlen,))(mm_k_dropout)
        show_layer_info('Reshape', mm_reshape)

        # Gate-weighted sum of the per-term scores.
        mean = Dot(axes=[1, 1])([mm_reshape, g])
        show_layer_info('Dot', mean)

        if target_mode == 'classification':
            out_ = Dense(2, activation='softmax')(mean)
            show_layer_info('Dense', out_)
        else:  # 'regression' or 'ranking' (validated above)
            out_ = Reshape((1,))(mean)
            show_layer_info('Reshape', out_)

        model = Model(inputs=[query, doc], outputs=out_)
        return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号