def _combine_last(self, r, h_t):
    """Combine ``r`` with the last GRU output ``h_t`` into ``h_star``.

    Computes ``h_star = tanh(r @ W_p + h_t @ W_x)``.

    Args:
        r: tensor of shape ``batch x n_dim``.
        h_t: tensor of shape ``batch x n_dim``; the output from the GRU unit.

    Uses params:
        self.W_p: ``n_dim x n_dim`` matrix applied to ``r``.
        self.W_x: ``n_dim x n_dim`` matrix applied to ``h_t``.

    Returns:
        h_star: tensor of shape ``batch x n_dim``.

    Note:
        ``torch.matmul`` gives the same result as ``torch.mm`` for the 2-D
        inputs above, and additionally accepts extra leading dimensions
        (e.g. ``batch x seq x n_dim``), which are preserved in the output.
    """
    W_p_r = torch.matmul(r, self.W_p)    # batch x n_dim
    W_x_h = torch.matmul(h_t, self.W_x)  # batch x n_dim
    # torch.tanh instead of F.tanh: torch.nn.functional.tanh is deprecated
    # and emits a warning on every call.
    h_star = torch.tanh(W_p_r + W_x_h)   # batch x n_dim
    return h_star
rte_model.py 文件源码
python
阅读 44
收藏 0
点赞 0
评论 0
评论列表
文章目录