rnnlayer.py source code

python

Project: recurrent-attention-for-QA-SQUAD-based-on-keras  Author: wentaozhu
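def step(self, inputs, states):
    # One recurrent step of a GRU-style cell with an extra gated context term.
    # `inputs` is sliced as two parts along the feature axis: the first
    # self.h_dim columns feed the attention score, the rest is the regular input.
    h_tm1 = states[0]  # previous memory
    # B_U = states[1]  # dropout matrices for recurrent units (currently unused)
    # B_W = states[2]
    h_in = inputs[:, :self.h_dim]
    x_in = inputs[:, self.h_dim:]

    # Additive attention score from the previous state and the first input part.
    h_tm1a = K.dot(h_tm1, self.Wa)
    eij = K.dot(K.tanh(h_tm1a + K.dot(h_in, self.Ua)), self.Va)
    eijs = K.repeat_elements(eij, self.h_dim, axis=1)

    # Disabled softmax-weighted sum over self.h, kept for reference:
    # alphaij = K.softmax(eijs)  # batchsize * lenh; h is batchsize * lenh * ndim
    # ci = K.permute_dimensions(K.permute_dimensions(self.h, [2, 0, 1]) * alphaij, [1, 2, 0])
    # cisum = K.sum(ci, axis=1)
    # The active code instead scales the first input part by the raw score.
    cisum = eijs * h_in

    # Update gate (zi) and reset gate (ri), computed in one fused projection.
    zr = K.sigmoid(K.dot(x_in, self.Wzr) + K.dot(h_tm1, self.Uzr) + K.dot(cisum, self.Czr))
    zi = zr[:, :self.units]
    ri = zr[:, self.units: 2 * self.units]
    # Candidate state, then interpolation with the previous state.
    si_ = K.tanh(K.dot(x_in, self.W) + K.dot(ri * h_tm1, self.U) + K.dot(cisum, self.C))
    si = (1 - zi) * h_tm1 + zi * si_
    return si, [si]  # alternative: h_tm1, [h_tm1]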
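
The step above can be traced outside Keras with a small NumPy sketch, which makes the tensor shapes explicit. This is not the project's code: the function names, the `a_dim` attention width, and every weight shape below are assumptions chosen so that the matrix products in `step` line up. The real layer may size its weights differently.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def make_params(rng, h_dim, x_dim, units, a_dim):
    # Hypothetical weight shapes, inferred only from the products in step().
    shapes = {
        "Wa": (units, a_dim), "Ua": (h_dim, a_dim), "Va": (a_dim, 1),
        "Wzr": (x_dim, 2 * units), "Uzr": (units, 2 * units), "Czr": (h_dim, 2 * units),
        "W": (x_dim, units), "U": (units, units), "C": (h_dim, units),
    }
    return {name: 0.1 * rng.standard_normal(shape) for name, shape in shapes.items()}


def attention_gru_step(inputs, h_tm1, p, h_dim, units):
    # Split the concatenated input the same way step() slices it.
    h_in = inputs[:, :h_dim]
    x_in = inputs[:, h_dim:]

    # Scalar score per example, repeated across h_dim columns, then used
    # to scale h_in (no softmax, matching the active code path).
    eij = np.tanh(h_tm1 @ p["Wa"] + h_in @ p["Ua"]) @ p["Va"]  # (batch, 1)
    cisum = np.repeat(eij, h_dim, axis=1) * h_in               # (batch, h_dim)

    # Fused update/reset gates, candidate state, and interpolation.
    zr = sigmoid(x_in @ p["Wzr"] + h_tm1 @ p["Uzr"] + cisum @ p["Czr"])
    zi, ri = zr[:, :units], zr[:, units:2 * units]
    si_ = np.tanh(x_in @ p["W"] + (ri * h_tm1) @ p["U"] + cisum @ p["C"])
    return (1 - zi) * h_tm1 + zi * si_


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch, h_dim, x_dim, units, a_dim = 4, 3, 5, 6, 2
    params = make_params(rng, h_dim, x_dim, units, a_dim)
    inputs = rng.standard_normal((batch, h_dim + x_dim))
    state = np.zeros((batch, units))
    new_state = attention_gru_step(inputs, state, params, h_dim, units)
    print(new_state.shape)
```

Because the new state is an interpolation between the previous state and a `tanh` output, starting from a zero state keeps every entry strictly inside (-1, 1), which is a quick sanity check on the wiring.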