def parallel(layer):
    """Create a parallel (map/distributed) operation from ``layer``.

    The returned callable splits its input tensor along dimension 0, applies
    the resolved ``layer`` to each slice independently, and stacks the
    per-slice results back together along dimension 0.

    Args:
        layer: Anything accepted by ``Layer.resolve``. It is resolved each
            time the returned operation is called, not when ``parallel``
            itself is called.

    Returns:
        A function ``func(module, x)`` whose ``pure`` attribute is ``True``.
    """
    def func(module, x):
        """Apply the resolved layer to every slice of ``x`` along dim 0.

        Args:
            module: Passed through unchanged as the first argument of the
                resolved layer for every slice.
            x: A tensor. It is unbound along dimension 0, and each slice is
                passed to the resolved layer.

        Returns:
            ``torch.stack`` of the per-slice results along dimension 0.

        Note:
            If ``x`` has size 0 along dimension 0 there is nothing to stack,
            and ``torch.stack`` raises.
        """
        # Resolve once per call rather than once per slice. The lookup does
        # not depend on the slice being processed. It is still done here
        # (not in ``parallel``) so resolution stays lazy.
        op = Layer.resolve(layer)
        slices = torch.unbind(x, 0)
        return torch.stack([op(module, s) for s in slices], 0)

    # NOTE(review): ``pure`` is presumably read elsewhere in the framework to
    # mark this operation as side-effect free; confirm against its consumers.
    func.pure = True
    return func
### EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF
评论列表
文章目录