a8_dynamic_memory_network.py source code

python

Project: text_classification    Author: brightmart
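For a single candidate sentence $c_t$ (with $h$ = hidden_size), `attention_mechanism_parallel` computes the gate shown below. $B_q$ and $B_m$ are shorthand introduced here for the bilinear terms returned by `x1Wx2_parallel`, whose body is not part of this excerpt. $W_1, b_1, w_2, b_2$ are names assumed here for the weights and biases of the two `tf.layers.dense` calls.

$$
z_t = \big[\, c_t;\ m;\ q;\ c_t \circ q;\ c_t \circ m;\ |c_t - q|;\ |c_t - m|;\ B_q(c_t, q);\ B_m(c_t, m) \,\big] \in \mathbb{R}^{9h}
$$

$$
g_t = \sigma\big( w_2^\top \tanh(W_1 z_t + b_1) + b_2 \big), \qquad W_1 \in \mathbb{R}^{3h \times 9h},\ w_2 \in \mathbb{R}^{3h}
$$

Because $z_t$ depends only on $c_t$, $m$ and $q$, all sentences can be scored in one batched pass over the story_length axis. This is what the parallel implementation exploits.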
```python
import tensorflow as tf  # TensorFlow 1.x API: tf.layers.dense is used below


def attention_mechanism_parallel(self, c_full, m, q, i):
    """Parallel implementation of the gate function, given a list of candidate sentences, a query, and the previous memory.

    Input:
        c_full: candidate facts. shape: [batch_size, story_length, hidden_size]
        m: previous memory. shape: [batch_size, hidden_size]
        q: question. shape: [batch_size, hidden_size]
        i: index of the current pass, appended to the scope names of the bilinear terms.
    Output: a scalar gate score per candidate sentence (in batch). shape: [batch_size, story_length]
    """
    q = tf.expand_dims(q, axis=1)  # [batch_size, 1, hidden_size]
    m = tf.expand_dims(m, axis=1)  # [batch_size, 1, hidden_size]

    # 1. Define a large feature vector z(c, m, q) that captures a variety of similarities
    #    between the input, the memory, and the question vector.
    c_q_elementwise = tf.multiply(c_full, q)    # [batch_size, story_length, hidden_size]
    c_m_elementwise = tf.multiply(c_full, m)    # [batch_size, story_length, hidden_size]
    c_q_minus = tf.abs(tf.subtract(c_full, q))  # [batch_size, story_length, hidden_size]
    c_m_minus = tf.abs(tf.subtract(c_full, m))  # [batch_size, story_length, hidden_size]
    # Bilinear terms c^T W q and c^T W m
    c_w_q = self.x1Wx2_parallel(c_full, q, "c_w_q" + str(i))  # [batch_size, story_length, hidden_size]
    c_w_m = self.x1Wx2_parallel(c_full, m, "c_w_m" + str(i))  # [batch_size, story_length, hidden_size]
    # tf.concat does not broadcast, so q and m are repeated along the story_length axis
    q_tile = tf.tile(q, [1, self.story_length, 1])  # [batch_size, story_length, hidden_size]
    m_tile = tf.tile(m, [1, self.story_length, 1])  # [batch_size, story_length, hidden_size]
    z = tf.concat([c_full, m_tile, q_tile, c_q_elementwise, c_m_elementwise,
                   c_q_minus, c_m_minus, c_w_q, c_w_m], 2)  # [batch_size, story_length, hidden_size*9]
    # 2. Two-layer feed-forward network
    g = tf.layers.dense(z, self.hidden_size * 3, activation=tf.nn.tanh)  # [batch_size, story_length, hidden_size*3]
    g = tf.layers.dense(g, 1, activation=tf.nn.sigmoid)                  # [batch_size, story_length, 1]
    g = tf.squeeze(g, axis=2)                                            # [batch_size, story_length]
    return g
```
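The following minimal NumPy sketch checks the shape bookkeeping of the same computation without building a TensorFlow graph. It is not the repository's code. The weights are random and passed in a hypothetical `params` dict. `x1Wx2_parallel`, which is not shown in this excerpt, is replaced by a hypothetical stand-in `(c W) * x` that has the same [batch_size, story_length, hidden_size] output shape.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def attention_gate_parallel(c_full, m, q, params):
    """NumPy sketch of the parallel gate: returns scores of shape [batch, story_length]."""
    story_length = c_full.shape[1]
    q = q[:, None, :]  # [batch, 1, hidden], like tf.expand_dims(q, axis=1)
    m = m[:, None, :]  # [batch, 1, hidden]

    # Element-wise features: q and m broadcast across the story_length axis.
    c_q_elementwise = c_full * q
    c_m_elementwise = c_full * m
    c_q_minus = np.abs(c_full - q)
    c_m_minus = np.abs(c_full - m)
    # Hypothetical stand-in for x1Wx2_parallel: (c W) * x, shape [batch, story_length, hidden].
    c_w_q = (c_full @ params["W_bq"]) * q
    c_w_m = (c_full @ params["W_bm"]) * m
    # Concatenation does not broadcast, so q and m are tiled explicitly.
    q_tile = np.tile(q, (1, story_length, 1))
    m_tile = np.tile(m, (1, story_length, 1))
    z = np.concatenate([c_full, m_tile, q_tile, c_q_elementwise, c_m_elementwise,
                        c_q_minus, c_m_minus, c_w_q, c_w_m], axis=2)  # [batch, story_length, 9*hidden]

    # Two-layer feed-forward network applied to every sentence at once.
    g = np.tanh(z @ params["W1"] + params["b1"])  # [batch, story_length, 3*hidden]
    g = sigmoid(g @ params["W2"] + params["b2"])  # [batch, story_length, 1]
    return np.squeeze(g, axis=2)                  # [batch, story_length]


rng = np.random.default_rng(0)
batch, story_length, hidden = 2, 5, 4
params = {  # random, hypothetical weights
    "W_bq": rng.normal(size=(hidden, hidden)),
    "W_bm": rng.normal(size=(hidden, hidden)),
    "W1": rng.normal(size=(9 * hidden, 3 * hidden)),
    "b1": np.zeros(3 * hidden),
    "W2": rng.normal(size=(3 * hidden, 1)),
    "b2": np.zeros(1),
}
c_full = rng.normal(size=(batch, story_length, hidden))
m = rng.normal(size=(batch, hidden))
q = rng.normal(size=(batch, hidden))

g = attention_gate_parallel(c_full, m, q, params)
print(g.shape)  # (2, 5)
```

If you score one sentence on its own (pass `c_full[:, 1:2, :]`), you get the same value as column 1 of the batched result. Each sentence's gate is independent of the others, which is why all of them can be computed in a single pass.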