linear.py source code

python

Project: static-define-by-run    Author: bkvogel
import numpy as np


def forward(self, inputs):
    # TODO: This is only compatible with NumPy, not yet with CuPy.
    x = inputs[0]
    W = inputs[1]
    # Notes:
    # To be compatible with the "static graph" feature, every output
    # array of this forward function must be allocated explicitly
    # (allocated directly with the input's dtype, so no extra copy):
    y = np.empty((x.shape[0], W.shape[0]), dtype=x.dtype)
    # This is required because all of the "static_*()" functions
    # follow the convention that output arrays are supplied as input
    # arguments. A "static_*()" function is not allowed to return
    # anything other than `None`. This prevents output arrays from
    # being allocated dynamically while the static schedule executes,
    # because that would break the model.
    if len(inputs) == 3:
        bias = inputs[2]
        # Note: `y` is the output array.
        self.static_linear(x, W, bias, y)
    else:
        # Note: `y` is the output array.
        self.static_linear_no_bias(x, W, y)
    # forward() returns a tuple of output arrays, hence the trailing comma.
    return y,
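A minimal sketch of the output-argument convention described in the comments above, assuming NumPy only (the source notes CuPy is not yet supported). `static_linear` and `static_linear_no_bias` here are hypothetical free-function stand-ins for the methods the snippet calls; their real implementations are not shown, and the sketch assumes they compute the usual linear-layer result `x @ W.T (+ bias)`, which matches the `(x.shape[0], W.shape[0])` output shape. `run_schedule` is likewise a hypothetical helper that replays recorded calls.

```python
import numpy as np


# Hypothetical stand-ins for the static_linear / static_linear_no_bias
# methods called in forward(). Assumption: they compute x @ W.T (+ bias).
# Each writes into the preallocated array `y` and returns None.
def static_linear_no_bias(x, W, y):
    """Write x @ W.T into y in place."""
    np.dot(x, W.T, out=y)  # `out` must be C-contiguous with the result dtype


def static_linear(x, W, bias, y):
    """Write x @ W.T + bias into y in place."""
    np.dot(x, W.T, out=y)
    y += bias  # broadcasts the bias vector over the batch dimension


def forward(inputs):
    """Free-function version of the forward() method above (no `self`)."""
    x, W = inputs[0], inputs[1]
    # The only allocation: the output buffer, with the input's dtype.
    y = np.empty((x.shape[0], W.shape[0]), dtype=x.dtype)
    if len(inputs) == 3:
        static_linear(x, W, inputs[2], y)
    else:
        static_linear_no_bias(x, W, y)
    return y,


def run_schedule(schedule):
    """Hypothetical replay helper: re-run recorded (func, args) pairs.

    The static functions return None, so nothing is allocated here;
    results land in the output buffers captured in `args`.
    """
    for func, args in schedule:
        func(*args)


if __name__ == "__main__":
    x = np.ones((2, 3))
    W = np.arange(12, dtype=np.float64).reshape(4, 3)
    b = np.array([1.0, 0.0, -1.0, 2.0])

    (y,) = forward([x, W, b])
    print(y)  # each row is [4. 12. 20. 32.]

    # Record one call, then change the input in place and replay it.
    schedule = [(static_linear, (x, W, b, y))]
    x[:] = 2.0
    run_schedule(schedule)
    print(y)  # same `y` object, each row now [7. 24. 41. 62.]
```

Because each function fills `y` in place and returns `None`, replaying the recorded call after writing new values into `x` updates the same `y` buffer instead of allocating a new output array.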