def apply(self, nn, nodes):
    """Apply the current fold to the given neural module.

    Steps are executed in ascending key order. Within a step, all queued
    calls to the same op are batched into one invocation of
    ``getattr(nn, op)``. The result is then split back into one slice per
    queued call, so later steps and the final retrieval can index into it.

    Args:
        nn: Object exposing one callable attribute per op name recorded in
            ``self.steps``.
        nodes: Argument lists resolved against the computed values after all
            steps have run; forwarded to ``self._batch_args``.

    Returns:
        Whatever ``self._batch_args(nodes, values)`` returns for the final
        per-step values.

    Raises:
        Any exception from batching a step's arguments or from the final
        retrieval is re-raised after diagnostic context is printed to
        stdout. Exceptions raised by the op callables propagate unchanged.
    """
    values = {}
    for step in sorted(self.steps.keys()):
        step_values = values[step] = {}
        for op, calls in self.steps[step].items():
            func = getattr(nn, op)
            try:
                # `calls` holds one argument tuple per queued call;
                # zip(*calls) transposes it into per-position argument lists.
                batched_args = self._batch_args(zip(*calls), values)
            except Exception:
                print("Error while executing node %s[%d] with args: %s" % (
                    op, step, calls))
                raise
            # Split into one slice per queued call. The first argument's
            # leading dimension only equals this when every call contributes
            # exactly one row, so count the calls directly. With no
            # arguments the single result is kept whole, as before.
            arg_size = len(calls) if batched_args else 1
            res = func(*batched_args)
            if isinstance(res, (tuple, list)):
                step_values[op] = [torch.chunk(x, arg_size) for x in res]
            else:
                step_values[op] = torch.chunk(res, arg_size)

    def _describe_size(node):
        # The failed retrieval may be exactly this lookup, so the diagnostic
        # must not raise and replace the original exception.
        try:
            return str(node.get(values).size())
        except Exception as err:
            return "<unavailable: %s>" % type(err).__name__

    try:
        return self._batch_args(nodes, values)
    except Exception:
        # Wrap in a 1-tuple: a bare `% nodes` raises TypeError when the
        # caller passes a tuple, hiding the real error.
        print("Retrieving %s" % (nodes,))
        for lst in nodes:
            if len(lst) > 0 and isinstance(lst[0], Fold.Node):
                print(', '.join(_describe_size(x) for x in lst))
        raise
评论列表
文章目录