layers.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def _instantiate_subnet(self, batch, block_idx, seq_prefix):
    """Recursively build the subnetwork of block `block_idx` at path `seq_prefix`.

    Leaf level: when `len(seq_prefix)` equals `self._fractal_block_depth`, the
    result is `self._children[block_idx, seq_prefix]` applied to `batch`.

    Internal levels produce two branches:
      * base branch: `self._children[block_idx, seq_prefix](batch)`.
      * recursive branch: `batch` is run through the subnet at
        `seq_prefix + (0,)`, and that output through the subnet at
        `seq_prefix + (1,)`.

    A choice tensor, `self._drop_path_choices[self._choice_id[(block_idx,
    seq_prefix)]]`, decides which branches are used:
      * `self._BOTH` -> `self._mixer(base, recursive)`
      * `self._JUST_BASE` -> base branch
      * `self._JUST_RECURSE` -> recursive branch
      * anything else -> base branch (the `tf.case` default)

    A branch that is not needed is swapped for `tf.zeros_like(batch)` through
    `tf.cond`, so it is not executed at run time.

    Args:
      batch: input tensor fed into this subnet (presumably a batch of
        activations -- shape/dtype not visible here, confirm with caller).
      block_idx: block identifier; used as a key into `self._children` and
        `self._choice_id`.
      seq_prefix: tuple path within the block; grows by `(0,)` / `(1,)` per
        recursion level.

    Returns:
      The subnet output tensor. At internal levels its static shape is set to
      `batch.get_shape()`.
    """
    def apply_child():
      return self._children[block_idx, seq_prefix](batch)

    # Leaf of the recursion: no choice to make, just run the child.
    if len(seq_prefix) == self._fractal_block_depth:
      return apply_child()

    def make_zeros():
      return tf.zeros_like(batch)

    def chain_children():
      # Two sub-subnets in series: prefix+(0,) feeds prefix+(1,).
      left_out = self._instantiate_subnet(batch, block_idx, seq_prefix + (0,))
      return self._instantiate_subnet(left_out, block_idx, seq_prefix + (1,))

    choice_index = self._choice_id[(block_idx, seq_prefix)]
    choice = self._drop_path_choices[choice_index]
    shape = batch.get_shape()

    # Base branch is needed unless only the recursive path was chosen.
    use_base = tf.not_equal(choice, self._JUST_RECURSE)
    base_branch = tf.cond(use_base, apply_child, make_zeros)
    # tf.cond can lose static shape information; restore it from the input.
    base_branch.set_shape(shape)

    # Recursive branch is needed unless only the base path was chosen.
    use_recursion = tf.not_equal(choice, self._JUST_BASE)
    recursive_branch = tf.cond(use_recursion, chain_children, make_zeros)
    recursive_branch.set_shape(shape)

    pred_fn_pairs = [
        (tf.equal(choice, self._BOTH),
         lambda: self._mixer(base_branch, recursive_branch)),
        (tf.equal(choice, self._JUST_BASE), lambda: base_branch),
        (tf.equal(choice, self._JUST_RECURSE), lambda: recursive_branch),
    ]
    merged = tf.case(pred_fn_pairs, lambda: base_branch)
    merged.set_shape(shape)
    return merged
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号