import torch


def batch_matmul_bias(seq, weight, bias, nonlinearity=''):
    # seq: (batch, n, in_dim); weight: (in_dim, out_dim);
    # bias is expected to be a column vector of shape (out_dim, 1)
    s = None
    bias_dim = bias.size()
    for i in range(seq.size(0)):
        # Affine transform of the i-th batch element: (n, out_dim)
        _s = torch.mm(seq[i], weight)
        _s_bias = _s + bias.expand(bias_dim[0], _s.size(0)).transpose(0, 1)
        if nonlinearity == 'tanh':
            _s_bias = torch.tanh(_s_bias)
        _s_bias = _s_bias.unsqueeze(0)
        # Stack the per-element results back into a batch along dim 0
        if s is None:
            s = _s_bias
        else:
            s = torch.cat((s, _s_bias), 0)
    return s.squeeze()
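
A minimal usage sketch, assuming seq is a (batch, n, in_dim) tensor, weight is (in_dim, out_dim), and bias is a column vector of shape (out_dim, 1); these shapes are inferred from the expand/transpose logic above rather than stated in the original:

import torch

seq = torch.randn(4, 10, 32)     # (batch, n, in_dim) -- illustrative shapes only
weight = torch.randn(32, 64)     # (in_dim, out_dim)
bias = torch.randn(64, 1)        # column bias, matching the expand/transpose above

out = batch_matmul_bias(seq, weight, bias, nonlinearity='tanh')
print(out.shape)                 # torch.Size([4, 10, 64])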