read_lua_file.py source code

python

Project: pytorch-coriander    Author: hughperkins

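import torch
from torch._thnn import type2backend  # legacy (pre-1.0) PyTorch: tensor type -> THNN backend


def _load_backend(obj):
    # Fast path: the deserialized Lua object already records its tensor type
    if hasattr(obj, '_type'):
        obj._backend = type2backend[obj._type]
        return
    # Try to find tensor attributes and infer the type from them
    for key in dir(obj):
        attr = getattr(obj, key)
        if torch.is_tensor(attr):
            try:
                obj._backend = type2backend[type(attr)]
            except KeyError:
                pass
    # Monkey-patch the forward pass to capture the type of the first input
    updateOutput_orig = obj.updateOutput

    def updateOutput_patch(*args):
        input = args[0]
        # Inputs may be (nested) tables of tensors; descend to the first tensor
        while not torch.is_tensor(input):
            input = input[0]
        obj._backend = type2backend[type(input)]
        # One-shot patch: restore the original method, then run it
        obj.updateOutput = updateOutput_orig
        return obj.updateOutput(*args)
    obj.updateOutput = updateOutput_patch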
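The last step of _load_backend is a one-shot monkey patch: the instance's updateOutput is temporarily shadowed by a wrapper that inspects the first input, records a backend, restores the original method, and delegates to it. The sketch below isolates that pattern without PyTorch. It is illustrative only: TYPE2BACKEND, ToyModule and load_backend_lazily are hypothetical stand-ins, with plain Python types and strings in place of tensor types and THNN backends, and an isinstance check on list/tuple in place of torch.is_tensor.

```python
# Hypothetical stand-in for type2backend: Python type -> backend name.
TYPE2BACKEND = {float: "float_backend", int: "int_backend"}


class ToyModule:
    """Hypothetical module whose backend is unknown until it sees an input."""

    def __init__(self):
        self._backend = None

    def updateOutput(self, x):
        return (self._backend, x)


def load_backend_lazily(obj):
    """Defer backend selection to the first updateOutput call (one-shot patch)."""
    original = obj.updateOutput  # bound method looked up through the class

    def patched(*args):
        inp = args[0]
        # Unwrap nested lists/tuples until a leaf value is reached
        while isinstance(inp, (list, tuple)):
            inp = inp[0]
        obj._backend = TYPE2BACKEND[type(inp)]
        # Restore the original method so later calls skip this wrapper
        obj.updateOutput = original
        return original(*args)

    # An instance attribute shadows the method defined on the class
    obj.updateOutput = patched


m = ToyModule()
load_backend_lazily(m)
print(m.updateOutput([[2.5]]))  # ('float_backend', [[2.5]])
print(m.updateOutput(7))        # ('float_backend', 7): backend fixed by the first call
```

As in the original, restoring stores a bound method on the instance, which creates a small reference cycle (obj -> method -> obj); deleting the instance attribute instead (del obj.updateOutput) would fall back to the class method without that cycle.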