def _instantiate_subnet(self, batch, block_idx, seq_prefix):
    """Build the fractal subnetwork of block `block_idx` rooted at `seq_prefix`.

    `seq_prefix` is a tuple path that grows by one element (0 or 1) per
    recursion level. Every node has two candidate branches:

    * base:      the child ``self._children[block_idx, seq_prefix]`` applied
                 directly to `batch`;
    * recursive: the subnet at ``seq_prefix + (0,)`` applied to `batch`,
                 followed by the subnet at ``seq_prefix + (1,)`` applied to
                 that result.

    When ``len(seq_prefix) == self._fractal_block_depth`` only the base branch
    exists, and its output is returned as-is (no gating, no shape pinning).

    Above that depth, a choice value is read from
    ``self._drop_path_choices[self._choice_id[(block_idx, seq_prefix)]]``:

    * ``self._BOTH``         -> ``self._mixer(base, recursive)``
    * ``self._JUST_BASE``    -> base (recursive branch replaced by zeros)
    * ``self._JUST_RECURSE`` -> recursive (base branch replaced by zeros)
    * anything else          -> base (the ``tf.case`` default)

    A deselected branch is swapped for ``tf.zeros_like(batch)`` through
    ``tf.cond``.

    Args:
      batch: Input tensor for this node. NOTE(review): children, the mixer and
        the recursive subnets are assumed to return tensors with the same shape
        and dtype as `batch`. Their outputs share a ``tf.cond`` with
        ``tf.zeros_like(batch)`` and are pinned to ``batch.get_shape()``.
        Confirm against the child layers.
      block_idx: Index of the fractal block; the first element of every key.
      seq_prefix: Tuple path of this node within the block.

    Returns:
      The subnet's output tensor. In the gated case its static shape is set to
      ``batch.get_shape()``.
    """
    key = (block_idx, seq_prefix)

    def apply_child():
        # Child layer owned by exactly this node, applied straight to the input.
        return self._children[key](batch)

    def zeros():
        return tf.zeros_like(batch)

    def apply_children_in_sequence():
        # Left sub-path first; its output is the input of the right sub-path.
        left = self._instantiate_subnet(batch, block_idx, seq_prefix + (0,))
        return self._instantiate_subnet(left, block_idx, seq_prefix + (1,))

    def pin_shape(tensor):
        # Presumably the conditional ops may not carry a fully known static
        # shape through, so re-assert the input's shape.
        # NOTE(review): confirm.
        tensor.set_shape(batch.get_shape())
        return tensor

    # Deepest level: no drop-path gating at all.
    if len(seq_prefix) == self._fractal_block_depth:
        return apply_child()

    # NOTE(review): ops are created in the same order as the original code
    # (choice, base cond, recursive cond, predicates, case). In graph mode the
    # creation order can influence auto-generated op/layer names.
    choice = self._drop_path_choices[self._choice_id[key]]

    base = pin_shape(
        tf.cond(tf.not_equal(choice, self._JUST_RECURSE), apply_child, zeros))
    recursive = pin_shape(
        tf.cond(tf.not_equal(choice, self._JUST_BASE),
                apply_children_in_sequence, zeros))

    def mix_both():
        return self._mixer(base, recursive)

    def only_base():
        return base

    def only_recursive():
        return recursive

    def fallback_to_base():
        return base

    branches = [
        (tf.equal(choice, self._BOTH), mix_both),
        (tf.equal(choice, self._JUST_BASE), only_base),
        (tf.equal(choice, self._JUST_RECURSE), only_recursive),
    ]
    return pin_shape(tf.case(branches, fallback_to_base))
评论列表
文章目录