def build(self):
    """Build the symbolic GRU forward pass and store it in ``self.activation``.

    Runs a gated recurrent unit over ``self.x`` / ``self.mask`` with
    ``theano.scan`` and then applies dropout when ``self.p > 0``.

    Reads ``self.x``, ``self.mask``, ``self.E``, ``self.Wz``/``self.bz``,
    ``self.Wr``/``self.br``, ``self.Wxc``/``self.Whc``/``self.bc``,
    ``self.f``, ``self.n_hidden``, ``self.bptt``, ``self.p``, ``self.rng``
    and ``self.is_train``; writes ``self.activation``.

    NOTE(review): assumes ``self.x`` holds integer token ids laid out as
    (time, batch) and ``self.mask`` uses the same layout with 0/1 entries;
    ``self.f`` is presumably a sigmoid -- confirm against the caller.
    """
    # Zero initial state, one row per batch element (the batch size is read
    # from the last axis of self.x).
    h0 = T.zeros((self.x.shape[-1], self.n_hidden), dtype=theano.config.floatX)

    def _step(tokens, step_mask, h_prev):
        # Look up the embeddings of this time step's token ids and pair them
        # with the previous hidden state for the gate projections.
        emb = self.E[tokens, :]
        emb_and_prev = T.concatenate([emb, h_prev], axis=1)

        # Update gate: weight given to the previous state in the new state.
        update = self.f(T.dot(emb_and_prev, self.Wz) + self.bz)
        # Reset gate: scales the previous state inside the candidate.
        reset = self.f(T.dot(emb_and_prev, self.Wr) + self.br)

        # Candidate state from the input and the reset-scaled previous state.
        candidate = T.tanh(
            T.dot(emb, self.Wxc) + T.dot(reset * h_prev, self.Whc) + self.bc
        )

        # Interpolate between the candidate and the previous state.
        keep_new = T.ones_like(update) - update
        h_new = keep_new * candidate + update * h_prev

        # Multiply by the per-row mask so positions with mask 0 get a zero state.
        return h_new * step_mask[:, None]

    hidden, _scan_updates = theano.scan(
        fn=_step,
        sequences=[self.x, self.mask],
        outputs_info=h0,
        truncate_gradient=self.bptt,
    )

    if self.p > 0:
        # Training: multiply by a Bernoulli keep-mask (keep prob 1 - p).
        # Inference: scale by the expected keep probability instead.
        keep = self.rng.binomial(
            n=1, p=1 - self.p, size=hidden.shape, dtype=theano.config.floatX
        )
        self.activation = T.switch(
            self.is_train, hidden * keep, hidden * (1 - self.p)
        )
    else:
        # Both branches are the same tensor; the switch presumably keeps
        # self.is_train referenced by the graph so a theano.function that
        # takes it as an input does not reject it as unused -- confirm.
        self.activation = T.switch(self.is_train, hidden, hidden)
评论列表
文章目录