def recursiveType(param, type, tensorCache=None):
    """Recursively convert ``param`` to the tensor type named by ``type``.

    Args:
        param: A list (converted element by element, in place), a ``Module``
            or ``Criterion`` (converted through its own ``type`` method), a
            tensor, or any other object (returned unchanged).
        type: Dotted type name compared against ``torch.typename(param)``.
            The matching storage type name is derived by replacing
            ``'Tensor'`` with ``'Storage'``. The name shadows the builtin
            but is kept because it is part of the public signature.
        tensorCache: Optional dict mapping ``_cdata`` handles of already
            converted tensors/storages to their converted counterparts, so
            that tensors sharing a storage still share one after
            conversion. A fresh dict is created when omitted; a dict
            supplied by the caller is updated in place.

    Returns:
        The converted tensor when ``param`` is a tensor of a different
        type; otherwise ``param`` itself (lists are mutated in place).
    """
    # Imported lazily inside the function, as in the original code
    # (presumably to avoid a circular import with these modules).
    from .Criterion import Criterion
    from .Module import Module

    # A mutable default would be shared by every call, leaving stale
    # ``_cdata`` keys behind. Compare against None (not truthiness) so an
    # empty dict passed by the caller is still filled in place.
    if tensorCache is None:
        tensorCache = {}

    if isinstance(param, list):
        for i, p in enumerate(param):
            param[i] = recursiveType(p, type, tensorCache)
    elif isinstance(param, (Module, Criterion)):
        param.type(type, tensorCache)
    elif torch.is_tensor(param) and torch.typename(param) != type:
        key = param._cdata
        if key in tensorCache:
            # This exact tensor was already converted earlier in the walk.
            newparam = tensorCache[key]
        else:
            newparam = torch.Tensor().type(type)
            storageType = type.replace('Tensor', 'Storage')
            param_storage = param.storage()
            if param_storage:
                storage_key = param_storage._cdata
                if storage_key not in tensorCache:
                    # Convert the underlying storage once; later tensors
                    # backed by the same storage reuse this copy.
                    storage_cls = torch._import_dotted_name(storageType)
                    tensorCache[storage_key] = storage_cls(
                        param_storage.size()).copy_(param_storage)
                # Re-create the original view (offset, size, stride) on
                # top of the converted storage.
                newparam.set_(
                    tensorCache[storage_key],
                    param.storage_offset(),
                    param.size(),
                    param.stride()
                )
            tensorCache[key] = newparam
        param = newparam
    return param
评论列表
文章目录