def affine_nd(input, weight, bias=None):
    """Apply ``x @ weight + bias`` over the last dimension of an n-d input.

    A helper that makes the "wx + b" operation easy to apply to an
    n-dimensional ``input``. All leading dimensions are flattened into a
    single batch dimension, one matrix multiply is performed, and the
    leading dimensions are restored afterwards.

    Args:
        input (Tensor): Input data of arbitrary rank whose size is
            (d0, d1, ..., dn, input_dim). It need not be contiguous.
        weight (Tensor): A matrix of size (input_dim, output_dim). Note the
            orientation: it is right-multiplied (``x @ weight``), i.e. the
            transpose of the ``nn.Linear.weight`` layout.
        bias (Tensor, optional): A bias vector of size (output_dim,).
            If None, no bias is added.

    Returns:
        Tensor: The result, of size (d0, ..., dn, output_dim).
    """
    import torch  # make sure torch is in scope even if the module lacks the import

    input_size = input.size()
    # .view() raises on non-contiguous tensors (e.g. after a transpose);
    # .contiguous() is a no-op when the input is already contiguous.
    input_flat = input.contiguous().view(-1, input_size[-1])
    if bias is None:
        output_flat = torch.mm(input_flat, weight)
    else:
        # Expand bias to (batch, output_dim) explicitly rather than relying
        # on addmm broadcasting (presumably for older PyTorch versions).
        bias_expand = bias.unsqueeze(0).expand(input_flat.size(0), bias.size(0))
        output_flat = torch.addmm(bias_expand, input_flat, weight)
    # Restore the leading dimensions; weight's second axis is output_dim.
    output_size = tuple(input_size[:-1]) + (weight.size(1),)
    return output_flat.view(*output_size)
# (Blog-page navigation residue, "Comment list" / "Table of contents" — not part of the code.)