a2_transformer_classification.py source code

python

Project: text_classification  Author: brightmart
def inference(self):
        """Building blocks:
        encoder: 6 layers. Each layer has two sub-layers: the first is a multi-head self-attention mechanism; the second is a position-wise fully connected feed-forward network.
               Each sub-layer uses LayerNorm(x + Sublayer(x)). All dimensions = 512.
        decoder: 6 layers. Each layer has three sub-layers; the second performs multi-head attention over the output of the encoder stack.
               Each sub-layer uses LayerNorm(x + Sublayer(x)).
        Only the encoder is built in this method; its output is flattened and projected to class logits.
        """
        # 1. embedding for encoder input
        # 1.1 token embedding scaled by sqrt(d_model), plus a learned position embedding
        input_x_embeded = tf.nn.embedding_lookup(self.Embedding,self.input_x)  #[None,sequence_length,embed_size]
        input_x_embeded=tf.multiply(input_x_embeded,tf.sqrt(tf.cast(self.d_model,dtype=tf.float32)))
        input_mask=tf.get_variable("input_mask",[self.sequence_length,1],initializer=self.initializer) #despite its name, this variable serves as the position embedding
        input_x_embeded=tf.add(input_x_embeded,input_mask) #[None,sequence_length,embed_size]. position embedding, broadcast over batch and embed_size.

        # 2. encoder
        encoder_class=Encoder(self.d_model,self.d_k,self.d_v,self.sequence_length,self.h,self.batch_size,self.num_layer,input_x_embeded,input_x_embeded,dropout_keep_prob=self.dropout_keep_prob,use_residual_conn=self.use_residual_conn)
        Q_encoded,K_encoded = encoder_class.encoder_fn() #K_v_encoder

        Q_encoded=tf.reshape(Q_encoded,shape=(self.batch_size,-1)) #[batch_size,sequence_length*d_model]
        with tf.variable_scope("output"):
            logits = tf.matmul(Q_encoded, self.W_projection) + self.b_projection #logits shape:[batch_size,self.num_classes]
        print("logits:",logits)
        return logits
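
The shape flow in `inference` can be traced with a small NumPy sketch. This is only an illustration under stated assumptions, not the project's implementation: the sizes are hypothetical, `embed_size` is assumed equal to `d_model`, the encoder stack is replaced by an identity step, and `W_projection` is assumed to have shape `[sequence_length*d_model, num_classes]` (it is defined outside the method shown, so this is inferred from the `matmul`).

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch_size, sequence_length, d_model = 2, 5, 8
vocab_size, num_classes = 20, 3

rng = np.random.default_rng(0)
embedding = rng.normal(size=(vocab_size, d_model))   # stands in for self.Embedding
position = rng.normal(size=(sequence_length, 1))     # stands in for the "input_mask" variable
w_projection = rng.normal(size=(sequence_length * d_model, num_classes))  # assumed shape
b_projection = np.zeros(num_classes)

input_x = rng.integers(0, vocab_size, size=(batch_size, sequence_length))

# 1. Look up token embeddings and scale them by sqrt(d_model).
x = embedding[input_x] * np.sqrt(d_model)            # [batch, seq, d_model]
# The [seq, 1] position term broadcasts over the batch and d_model axes.
x = x + position                                     # [batch, seq, d_model]

# 2. Encoder stack omitted: treated as identity here (assumption).
encoded = x

# 3. Flatten each example and project to class logits.
flat = encoded.reshape(batch_size, -1)               # [batch, seq * d_model]
logits = flat @ w_projection + b_projection          # [batch, num_classes]
print(logits.shape)
```

Because the position variable has a trailing dimension of 1, every embedding dimension at a given position receives the same offset; a full `[sequence_length, d_model]` table would give each dimension its own value.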