def Pack_FwGrad(*args, **kwargs):
    """Forward-mode gradient for a pack/stack op.

    Stacks the per-input tangents along ``axis``, substituting zeros for
    inputs whose tangent is ``None``.

    Args:
        *args: ``args[0]`` is the op being differentiated; only its
            ``inputs`` attribute is read. ``args[1:]`` are the tangents,
            paired positionally with ``args[0].inputs``; ``None`` marks an
            input with no tangent.
            NOTE(review): presumably ``args[0]`` is a ``tf.Operation`` --
            confirm against the caller.
        **kwargs: must contain ``axis``, forwarded to ``tf.stack`` /
            ``tf.pack``.

    Returns:
        ``None`` if every tangent is ``None`` (including when there are no
        tangents at all); otherwise the stacked tensor.

    Raises:
        KeyError: if ``axis`` is missing from ``kwargs``.
    """
    tangents = args[1:]
    axis = kwargs["axis"]

    if all(t is None for t in tangents):
        log.error("Pack_FwGrad: all input tangents are None; returning None")
        return None

    # Fill in zeros shaped like the corresponding op input wherever no
    # tangent was supplied, so every element can be stacked together.
    # args[0] is only accessed here, matching the original control flow.
    filled = [
        t if t is not None else tf.zeros_like(x)
        for t, x in zip(tangents, list(args[0].inputs))
    ]

    # Old TensorFlow (0.x) exposes tf.pack; newer releases expose tf.stack.
    # Probe for the newer name instead of parsing the version string.
    stack = tf.stack if hasattr(tf, "stack") else tf.pack
    return stack(filled, axis=axis)
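# Stray text scraped from the source web page (not code):
# "评论列表" ("Comment list"), "文章目录" ("Article contents")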