def _load_backend(obj):
    """Attach a ``_backend`` attribute to *obj*, looked up in ``type2backend``.

    Resolution order:

    1. If *obj* has a ``_type`` attribute, set
       ``obj._backend = type2backend[obj._type]`` and return immediately
       (``updateOutput`` is left untouched in this case).
    2. Otherwise, scan every name in ``dir(obj)``. For each attribute that is a
       tensor whose type is a key of ``type2backend``, set ``obj._backend``
       from it. Later matches overwrite earlier ones, and unknown tensor types
       are skipped.
    3. Still in the no-``_type`` case, replace ``obj.updateOutput`` with a
       one-shot wrapper. On its first call, the wrapper re-derives
       ``obj._backend`` from the type of the first tensor reached in the
       input, restores the original ``updateOutput``, and delegates to it.

    NOTE(review): ``type2backend`` is a module-level mapping defined outside
    this function; presumably it maps tensor types to backend objects -
    confirm against its definition.

    Args:
        obj: The object to configure. When it has no ``_type`` attribute, it
            must expose ``updateOutput`` (accessed unconditionally in step 3).

    Raises:
        KeyError: If ``obj._type`` is not a key of ``type2backend``. Also
            raised on the first patched ``updateOutput`` call if the input
            tensor's type is not a key.
        TypeError: On the first patched call, if descending through the input
            reaches a ``str`` (indexing a string yields a string again, so the
            descent would otherwise never end) or a non-subscriptable object.
        IndexError: On the first patched call, if the descent reaches an empty
            sequence.
    """
    if hasattr(obj, '_type'):
        obj._backend = type2backend[obj._type]
        return

    # Try to find tensor attributes and infer the type from them.
    for key in dir(obj):
        attr = getattr(obj, key)
        if torch.is_tensor(attr):
            try:
                obj._backend = type2backend[type(attr)]
            except KeyError:
                # Deliberate best effort: an unknown tensor type is skipped;
                # the input-based patch below still gets a chance to decide.
                pass

    # Monkey patch the forward to capture the type of the first input.
    updateOutput_orig = obj.updateOutput

    def updateOutput_patch(*args):
        inp = args[0]
        # A non-tensor input is indexed with [0] repeatedly until a tensor is
        # reached (e.g. a list of tensors).
        while not torch.is_tensor(inp):
            if isinstance(inp, str):
                # 'abc'[0] -> 'a', 'a'[0] -> 'a', ...: the loop would hang.
                raise TypeError(
                    "cannot infer backend from input: expected a tensor or a "
                    "(nested) sequence of tensors, got a str")
            inp = inp[0]
        obj._backend = type2backend[type(inp)]
        # One-shot: put the original method back before delegating, so later
        # calls skip the detection entirely.
        obj.updateOutput = updateOutput_orig
        return obj.updateOutput(*args)

    obj.updateOutput = updateOutput_patch
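# Leftover web-page navigation labels ("Comment list", "Table of contents")
# from the page this code was copied from; not part of the program.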