import torch


def forward(self, u, x, bias, init=None, mask_h=None):
    # CPU implementation of the SRU forward pass (unidirectional path only).
    bidir = 2 if self.bidirectional else 1
    length = x.size(0) if x.dim() == 3 else 1
    batch = x.size(-2)
    d = self.d_out
    k = u.size(-1) // d                      # number of projected slices per unit
    k_ = k // 2 if self.bidirectional else k

    u = u.view(length, batch, d, k_)
    # c_0 is a zero state unless an initial state is provided
    cur = x.new_zeros(batch, d) if init is None else init
    size = (length, batch, d * bidir) if x.dim() == 3 else (batch, d * bidir)  # output shape (unused below)
    bias1, bias2 = bias.split(self.d_out)    # forget-gate and reset-gate biases
    u_ = [u.select(-1, i) for i in range(k_)]
    h = []
    # Highway input: the raw input x when k_ == 3, otherwise the 4th slice of u
    x_ = x if k_ == 3 else u_[3]

    for i in range(length):
        u0i, u1i, u2i = u_[0][i], u_[1][i], u_[2][i]
        g1 = torch.sigmoid(u1i + bias1)      # forget gate f_t
        g2 = torch.sigmoid(u2i + bias2)      # reset gate r_t
        # c_t = f_t * c_{t-1} + (1 - f_t) * u0, rewritten as (c_{t-1} - u0) * f_t + u0
        cur = (cur - u0i) * g1 + u0i
        if self.activation_type == 1:
            val = torch.tanh(cur)
        elif self.activation_type == 2:
            val = torch.relu(cur)
        else:
            val = cur                        # identity activation; without this branch val would be undefined
        if mask_h is not None:
            val = val * mask_h               # dropout mask on the hidden state
        xi = x_[i]
        # h_t = r_t * g(c_t) + (1 - r_t) * x_t (highway connection)
        h.append((val - xi) * g2 + xi)

    if self.bidirectional:
        raise NotImplementedError("the bidirectional case is not handled here")
    else:
        last_hidden = cur
    h = torch.stack(h)
    return h, last_hidden
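
For context, here is a minimal sketch of how this forward pass might be exercised. The SimpleNamespace host and all tensor shapes are illustrative assumptions, not part of the original code; in the real library the attributes d_out, bidirectional, and activation_type live on the SRU compute class.

import types
import torch

# Hypothetical host object carrying the attributes forward() reads.
cell = types.SimpleNamespace(d_out=4, bidirectional=False, activation_type=1)

length, batch, d, k = 5, 2, 4, 3
x = torch.randn(length, batch, d)        # input sequence
u = torch.randn(length, batch, d * k)    # projected input, k slices per unit
bias = torch.zeros(2 * d)                # forget- and reset-gate biases

h, last_hidden = forward(cell, u, x, bias)
print(h.shape, last_hidden.shape)        # torch.Size([5, 2, 4]) torch.Size([2, 4])

With k_ == 3 the highway term uses x directly, so h has the same trailing dimension as the input, and last_hidden is the internal state c_T from the final time step.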