basic.py source file

python

Project: unsupervised-treelstm Author: jihunchoi
import torch


def affine_nd(input, weight, bias):
    """
    A helper function that makes applying the "wx + b" operation to
    n-dimensional x easier.

    Args:
        input (Variable): Arbitrary input data, whose size is
            (d0, d1, ..., dn, input_dim)
        weight (Variable): A matrix of size (input_dim, output_dim);
            it is multiplied as input_flat @ weight, without a transpose
        bias (Variable): A bias vector of size (output_dim,)

    Returns:
        output: The result of size (d0, ..., dn, output_dim)
    """

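    input_size = input.size()
    # Collapse all leading dimensions into one: (d0 * ... * dn, input_dim).
    input_flat = input.view(-1, input_size[-1])
    # Repeat the bias along the rows so it matches the flattened batch.
    bias_expand = bias.unsqueeze(0).expand(input_flat.size(0), bias.size(0))
    # Computes bias_expand + input_flat @ weight in a single call.
    output_flat = torch.addmm(bias_expand, input_flat, weight)
    # Restore the leading dimensions; weight.size(1) is the output dimension.
    output_size = input_size[:-1] + (weight.size(1),)
    output = output_flat.view(*output_size)
    return output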
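
A minimal usage sketch, assuming a recent PyTorch version where plain tensors replace `Variable`. The function is restated compactly so the block runs on its own; the tensor sizes and the comparison against `torch.nn.functional.linear` are illustrative choices, not part of the original file. Because the body multiplies `input_flat` by `weight` directly, the weight passed in here has shape (input_dim, output_dim).

```python
import torch
import torch.nn.functional as F


def affine_nd(input, weight, bias):
    # Compact restatement of the function above so this block is self-contained.
    input_flat = input.reshape(-1, input.size(-1))
    # addmm broadcasts the 1-D bias over the rows of input_flat @ weight.
    output_flat = torch.addmm(bias, input_flat, weight)
    return output_flat.view(*input.size()[:-1], weight.size(1))


torch.manual_seed(0)
x = torch.randn(2, 3, 4, 5)   # (d0, d1, d2, input_dim)
weight = torch.randn(5, 7)    # (input_dim, output_dim), as the code multiplies it
bias = torch.randn(7)         # (output_dim,)

out = affine_nd(x, weight, bias)
print(out.shape)              # torch.Size([2, 3, 4, 7])

# F.linear expects a weight of shape (output_dim, input_dim), hence the transpose.
expected = F.linear(x, weight.t(), bias)
print(torch.allclose(out, expected, atol=1e-5))
```

Only the last dimension changes; every leading dimension of the input is carried through unchanged.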