test_retain.py source code

python

Project: retain  Author: mp2893
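import theano
import theano.tensor as T

# _slice (n-th block of width dim on the last axis) and numpy_floatX (cast to
# theano.config.floatX) are helpers defined elsewhere in the project.
def gru_layer(tparams, emb, name, hiddenDimSize):
    # emb is expected as (timesteps, n_samples, embDim)
    timesteps = emb.shape[0]
    if emb.ndim == 3:
        n_samples = emb.shape[1]
    else:
        n_samples = 1

    def stepFn(wx, h, U_gru):
        uh = T.dot(h, U_gru)  # recurrent projection of the previous hidden state
        r = T.nnet.sigmoid(_slice(wx, 0, hiddenDimSize) + _slice(uh, 0, hiddenDimSize))  # reset gate
        z = T.nnet.sigmoid(_slice(wx, 1, hiddenDimSize) + _slice(uh, 1, hiddenDimSize))  # update gate
        h_tilde = T.tanh(_slice(wx, 2, hiddenDimSize) + r * _slice(uh, 2, hiddenDimSize))  # candidate state
        h_new = z * h + ((1. - z) * h_tilde)  # z weights the previous state
        return h_new

    # Project the inputs for all timesteps at once; scan then iterates only the recurrence
    Wx = T.dot(emb, tparams['W_gru_' + name]) + tparams['b_gru_' + name]
    results, updates = theano.scan(
        fn=stepFn,
        sequences=[Wx],
        outputs_info=T.alloc(numpy_floatX(0.0), n_samples, hiddenDimSize),  # h_0 = 0
        non_sequences=[tparams['U_gru_' + name]],
        name='gru_layer',
        n_steps=timesteps)

    # Hidden states for every timestep: (timesteps, n_samples, hiddenDimSize)
    return results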
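To make the recurrence concrete without a Theano install, here is a minimal NumPy sketch of the same computation. It assumes, as the gate indices 0, 1 and 2 in stepFn imply, that the W_gru, U_gru and b_gru parameters pack the reset, update and candidate blocks side by side in columns of width hiddenDimSize. The names gru_layer_np and slice_gate are hypothetical and are not part of the retain project. The Python loop stands in for theano.scan: the input projection is computed for all timesteps up front, and only the hidden-state update is iterated, starting from h_0 = 0.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def slice_gate(x, n, dim):
    # Hypothetical helper: n-th block of width `dim` on the last axis
    # (0 = reset gate, 1 = update gate, 2 = candidate state)
    return x[..., n * dim:(n + 1) * dim]


def gru_layer_np(emb, W, U, b, hidden):
    """Sketch of the gru_layer recurrence.

    emb: (timesteps, n_samples, emb_dim); W: (emb_dim, 3*hidden);
    U: (hidden, 3*hidden); b: (3*hidden,). Returns (timesteps, n_samples, hidden).
    """
    timesteps, n_samples = emb.shape[0], emb.shape[1]
    wx_all = emb @ W + b                # input projection for every step at once
    h = np.zeros((n_samples, hidden))   # h_0 = 0, like the T.alloc initial state
    outputs = []
    for t in range(timesteps):          # plays the role of theano.scan
        wx = wx_all[t]
        uh = h @ U
        r = sigmoid(slice_gate(wx, 0, hidden) + slice_gate(uh, 0, hidden))
        z = sigmoid(slice_gate(wx, 1, hidden) + slice_gate(uh, 1, hidden))
        h_tilde = np.tanh(slice_gate(wx, 2, hidden) + r * slice_gate(uh, 2, hidden))
        h = z * h + (1.0 - z) * h_tilde  # z keeps the old state, 1 - z admits the candidate
        outputs.append(h)
    return np.stack(outputs)


rng = np.random.default_rng(0)
steps, batch, emb_dim, hidden_dim = 4, 2, 5, 3
emb = rng.normal(size=(steps, batch, emb_dim))
W = rng.normal(size=(emb_dim, 3 * hidden_dim)) * 0.1
U = rng.normal(size=(hidden_dim, 3 * hidden_dim)) * 0.1
b = np.zeros(3 * hidden_dim)

hs = gru_layer_np(emb, W, U, b, hidden_dim)
print(hs.shape)  # (4, 2, 3)
```

Because each new state is a convex combination (weights z and 1 - z) of the previous state and a tanh output, every hidden value stays strictly inside (-1, 1). With zero inputs and zero bias the state never leaves 0.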