def __init__(self,
             f,
             g,
             num_layers=1,
             f_side_input=None,
             g_side_input=None,
             use_efficient_backprop=True):
    """Build one `f` and one `g` template per layer of the reversible block.

    Args:
      f: a callable, or a list/tuple of exactly `num_layers` callables, used
        as the "f" function of each layer. A single callable is reused for
        every layer, but each layer still gets its own `tf.make_template`
        wrapper under its own scope name.
      g: same as `f`, for the "g" function of each layer.
      num_layers: number of layers in the block.
      f_side_input: optional side inputs for `f`; stored on
        `self.f_side_input`, with `None` (or any falsy value) replaced by `[]`.
      g_side_input: same as `f_side_input`, stored on `self.g_side_input`.
      use_efficient_backprop: stored on `self._use_efficient_backprop`; how
        it is consumed is not visible in this method.

    Raises:
      ValueError: if `f` or `g` is a list/tuple whose length differs from
        `num_layers`.
    """

    def _per_layer(fns, name):
        # Expand `fns` to one entry per layer. An explicit raise (rather than
        # `assert`) keeps the check active under `python -O`.
        if isinstance(fns, (list, tuple)):
            if len(fns) != num_layers:
                raise ValueError(
                    "Expected %d %s functions (one per layer), got %d." %
                    (num_layers, name, len(fns)))
            return list(fns)
        return [fns] * num_layers

    def _templates(fns, name):
        # One template per layer, scoped "revblock/revlayer_<i>/<name>".
        # create_scope_now_=True makes tf.make_template capture the scope at
        # construction time instead of on first call.
        return [
            tf.make_template("revblock/revlayer_%d/%s" % (i, name), fn,
                             create_scope_now_=True)
            for i, fn in enumerate(fns)
        ]

    # Validate both inputs before creating any template, so a bad `g` does
    # not leave half-built `f` templates behind.
    f_fns = _per_layer(f, "f")
    g_fns = _per_layer(g, "g")

    # Same creation order as before: all `f` templates, then all `g`.
    self.f = _templates(f_fns, "f")
    self.g = _templates(g_fns, "g")

    self.num_layers = num_layers
    self.f_side_input = f_side_input or []
    self.g_side_input = g_side_input or []
    self._use_efficient_backprop = use_efficient_backprop
评论列表
文章目录