def apply_shortcut(self, prev_inp, ch_in, ch_out, phase_train=None, w=None,
                   bn=None, stride=None):
    """Transform the shortcut input according to ``self.shortcut``.

    * ``'projection'``: convolve ``prev_inp`` with ``w`` -- using
      ``DilatedConv2D(w, rate=stride)`` when ``self.dilation`` is truthy,
      otherwise ``Conv2D(w, stride=stride)`` -- and pass the result through
      ``bn``.
    * ``'identity'``: zero-pad the end of the last axis by ``ch_out - ch_in``
      entries.  A ``stride`` greater than 1 is currently rejected by a
      leftover debug guard (see the note in the body).
    * any other value: ``prev_inp`` is left untouched.

    The resulting shape is logged through ``self.log`` in every case.

    Args:
        prev_inp: Input tensor.  The identity branch pads it with a
            four-axis padding spec, so it must be rank 4 there (presumably
            channels-last -- confirm against the caller).
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        phase_train: Forwarded to ``bn`` under the ``'phase_train'`` key
            (projection only).
        w: Weights handed to the convolution layer (projection only).
        bn: Callable taking ``{'input': ..., 'phase_train': ...}`` and
            returning the normalised tensor (projection only).
        stride: Stride (or dilation rate) of the shortcut.  ``None`` is
            treated as "no striding" in the identity branch.

    Returns:
        The transformed (or unchanged) shortcut tensor.

    Raises:
        Exception: If the identity shortcut is requested with
            ``ch_in > ch_out``, or with ``stride > 1`` (debug guard).
    """
    if self.shortcut == 'projection':
        if self.dilation:
            prev_inp = DilatedConv2D(w, rate=stride)(prev_inp)
        else:
            prev_inp = Conv2D(w, stride=stride)(prev_inp)
        prev_inp = bn({'input': prev_inp, 'phase_train': phase_train})
    elif self.shortcut == 'identity':
        pad_ch = ch_out - ch_in
        if pad_ch < 0:
            raise Exception('Must use projection when ch_in > ch_out.')
        # Only the last axis grows; the new entries are zeros appended after
        # the existing ones.
        prev_inp = tf.pad(prev_inp, [[0, 0], [0, 0], [0, 0], [0, pad_ch]])
        # ``stride`` defaults to None; comparing None with an int raises
        # TypeError on Python 3, so guard the comparison explicitly.
        if stride is not None and stride > 1:
            prev_inp = AvgPool(stride)(prev_inp)
            # NOTE(review): leftover debug guard.  It fires unconditionally
            # for stride > 1, so the pooled tensor above is never returned.
            # Confirm whether strided identity shortcuts are meant to be
            # supported before removing it.
            raise Exception('DEBUG Unknown')
    self.log.info('After proj shape: {}'.format(prev_inp.get_shape()))
    return prev_inp
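# [web-page residue, not part of the code] Comment list
# [web-page residue, not part of the code] Table of contents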