def batch_matmul(seq, weight, nonlinearity=''):
    """Multiply every matrix in a batch by the same weight matrix.

    Mathematically equivalent to stacking ``torch.mm(seq[i], weight)`` over
    the leading dimension of ``seq`` (optionally passing each product
    through ``tanh``) and then calling ``.squeeze()`` on the stack. It is
    done with one broadcasted ``torch.matmul`` call instead of a Python
    loop that re-concatenates the growing output on every iteration.

    Args:
        seq: Tensor whose leading dimension is the batch; each ``seq[i]``
            is a matrix of shape ``(rows, in_features)``, so ``seq`` is
            typically ``(batch, rows, in_features)``.
        weight: 2-D tensor of shape ``(in_features, out_features)``.
        nonlinearity: ``'tanh'`` applies ``torch.tanh`` to the products;
            any other value (including the default ``''``) applies nothing.

    Returns:
        Tensor of shape ``(batch, rows, out_features)`` with every size-1
        dimension removed by ``squeeze()``. For example, a batch of one
        drops the batch axis, and ``out_features == 1`` drops the last
        axis. An empty batch yields an empty tensor instead of raising.
    """
    # torch.matmul broadcasts the 2-D weight across the leading batch
    # dimension, yielding the same (batch, rows, out_features) stack the
    # per-sample torch.mm loop built, without the O(batch^2) copying that
    # repeated torch.cat on a growing tensor incurs.
    out = torch.matmul(seq, weight)
    if nonlinearity == 'tanh':
        out = torch.tanh(out)
    # Preserve the original squeeze() of all size-1 axes; callers may
    # depend on the reduced shape.
    return out.squeeze()
评论列表
文章目录