res_net.py source code

python

Project: tfplus  Author: renmengye
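def apply_shortcut(self, prev_inp, ch_in, ch_out, phase_train=None, w=None, stride=None):
    """Build the shortcut branch of a residual unit; returns (tensor, bn)."""
    if self.shortcut == 'projection':
        # Projection shortcut: apply a (dilated) convolution built from w,
        # then batch-normalize the result over ch_out channels.
        if self.dilation:
            prev_inp = DilatedConv2D(w, rate=stride)(prev_inp)
        else:
            prev_inp = Conv2D(w, stride=stride)(prev_inp)
        bn = BatchNorm(ch_out)
        prev_inp = bn({'input': prev_inp, 'phase_train': phase_train})
    elif self.shortcut == 'identity':
        # Identity shortcut: zero-pad the last (channel) axis up to ch_out.
        pad_ch = ch_out - ch_in
        if pad_ch < 0:
            raise Exception('Must use projection when ch_in > ch_out.')
        prev_inp = tf.pad(prev_inp, [[0, 0], [0, 0], [0, 0], [0, pad_ch]])
        # stride defaults to None; comparing None > 1 raises on Python 3.
        if stride is not None and stride > 1:
            prev_inp = AvgPool(stride)(prev_inp)
        bn = None
    else:
        # Without this branch, bn would be unbound at the return below.
        raise ValueError('Unknown shortcut type: {}'.format(self.shortcut))
    self.log.info('After proj shape: {}'.format(prev_inp.get_shape()))
    return prev_inp, bn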
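
The identity branch above can be hard to picture from the tf.pad call alone, so here is a minimal pure-Python sketch of the same idea on nested lists in NHWC order. The helper names (pad_channels, avg_pool, identity_shortcut) are hypothetical and not part of tfplus. The pooling window is assumed to equal the stride, with no padding; the actual AvgPool layer in tfplus may behave differently.

```python
def pad_channels(x, ch_out):
    """Zero-pad the channel axis of an NHWC nested list up to ch_out channels."""
    ch_in = len(x[0][0][0])
    pad_ch = ch_out - ch_in
    if pad_ch < 0:
        raise ValueError('Must use projection when ch_in > ch_out.')
    return [[[pixel + [0.0] * pad_ch for pixel in row] for row in img]
            for img in x]


def avg_pool(x, stride):
    """Average-pool with a stride x stride window and no padding (assumed)."""
    out = []
    for img in x:
        height, width, channels = len(img), len(img[0]), len(img[0][0])
        pooled = []
        for i in range(0, height - stride + 1, stride):
            row = []
            for j in range(0, width - stride + 1, stride):
                # Collect every pixel in the current window.
                window = [img[i + di][j + dj]
                          for di in range(stride) for dj in range(stride)]
                row.append([sum(p[c] for p in window) / len(window)
                            for c in range(channels)])
            pooled.append(row)
        out.append(pooled)
    return out


def identity_shortcut(x, ch_out, stride=1):
    """Pad channels first, then downsample, mirroring the order used above."""
    y = pad_channels(x, ch_out)
    if stride > 1:
        y = avg_pool(y, stride)
    return y


# One 2x2 image with a single channel, expanded to 3 channels and pooled by 2.
x = [[[[1.0], [2.0]],
      [[3.0], [4.0]]]]
y = identity_shortcut(x, ch_out=3, stride=2)
print(y)  # [[[[2.5, 0.0, 0.0]]]]
```

The added channels stay zero after pooling, so the shortcut adds no learned parameters. This is why the identity branch returns bn = None, while the projection branch returns a BatchNorm object.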