question_encoding.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:Constituent-Centric-Neural-Architecture-for-Reading-Comprehension 作者: shrshore 项目源码 文件源码
def process_leafs(self,inodes_h,inodes_c,emb_leaves):
        num_leaves = self.num_leaves
        embx=tf.gather(emb_leaves,tf.range(num_leaves))
        leaf_parent=tf.gather(self.t_par_leaf,tf.range(num_leaves))
        node_h=tf.identity(inodes_h)
        node_c=tf.identity(inodes_c)
        with tf.variable_scope('td_Composition',reuse=True):
            cW=tf.get_variable('cW',[self.hidden_dim+self.emb_dim,4*self.hidden_dim])
            cb=tf.get_variable('cb',[4*self.hidden_dim])
            bu,bo,bi,bf=tf.split(axis=0,num_or_size_splits=4,value=cb)
            idx_var=tf.constant(0)
            logging.warn('begin enumerate the idx_var')
            def _recurceleaf(node_h, node_c,idx_var):
                node_info=tf.gather(leaf_parent, idx_var)
                cur_embed=tf.gather(embx, idx_var)
                #initial node_h:[inode_size, dim_hidden]
                parent_h=tf.gather(node_h, node_info)
                parent_c=tf.gather(node_c, node_info)
                cur_input=tf.concat(values=[parent_h, cur_embed],axis=0)
                flat_=tf.reshape(cur_input, [-1])

                tmp=tf.matmul(tf.expand_dims(flat_,0),cW)

                u,o,i,f=tf.split(axis=1,num_or_size_splits=4,value=tmp)
                i=tf.nn.sigmoid(i+bi)
                o=tf.nn.sigmoid(o+bo)
                u=tf.nn.sigmoid(u+bu)
                f=tf.nn.sigmoid(f+bf)
                c=i*u+tf.reduce_sum(f*parent_c,[0])
                h=o*tf.nn.tanh(c)

                node_h=tf.concat(axis=0,values=[node_h,h])
                node_c=tf.concat(axis=0,values=[node_c,c])
                idx_var=tf.add(idx_var,1)
                return node_h, node_c, idx_var
            loop_cond=lambda a1,b1,idx_var:tf.less(idx_var,num_leaves)
            loop_vars=[node_h,node_c,idx_var]
            node_h,node_c,idx_var=tf.while_loop(loop_cond, _recurceleaf,loop_vars,shape_invariants=[tf.TensorShape([None,self.hidden_dim]),tf.TensorShape([None,self.hidden_dim]),idx_var.get_shape()])
            logging.warn('return new node_h, finished')
            return node_h,node_c
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号