gru.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:lmkit 作者: jiangnanhugo 项目源码 文件源码
def build(self):
        state_pre=T.zeros((self.x.shape[-1],self.n_hidden),dtype=theano.config.floatX)
        def _recurrence(x_t,m,h_tm1):
            x_e=self.E[x_t,:]
            concated=T.concatenate([x_e,h_tm1],axis=1)

            # Update gate
            z_t=self.f(T.dot(concated,self.Wz) + self.bz )

            # Input fate
            r_t=self.f(T.dot(concated,self.Wr) + self.br )

            # Cell update
            c_t=T.tanh(T.dot(x_e,self.Wxc)+T.dot(r_t*h_tm1,self.Whc)+self.bc)

            # Hidden state
            h_t=(T.ones_like(z_t)-z_t) * c_t + z_t * h_tm1

            # masking
            h_t=h_t*m[:,None]

            return h_t

        h,_=theano.scan(fn=_recurrence,
                        sequences=[self.x,self.mask],
                        outputs_info=state_pre,
                        truncate_gradient=self.bptt)

        # Dropout
        if self.p>0:
            drop_mask=self.rng.binomial(n=1,p=1-self.p,size=h.shape,dtype=theano.config.floatX)
            self.activation=T.switch(self.is_train,h*drop_mask,h*(1-self.p))
        else:
            self.activation=T.switch(self.is_train,h,h)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号