model.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:DeepWorks 作者: daigo0927 项目源码 文件源码
def _shortcut(inputs, x): # x = f(inputs)
    # shortcut path
    _, inputs_h, inputs_w, inputs_ch = inputs.shape.as_list()
    _, x_h, x_w, x_ch = x.shape.as_list()
    stride_h = int(round(inputs_h / x_h))
    stride_w = int(round(inputs_w / x_w))
    equal_ch = inputs_ch == x_ch

    if stride_h>1 or stride_w>1 or not equal_ch:
        shortcut = tcl.conv2d(inputs,
                              num_outputs = x_ch,
                              kernel_size = (1, 1),
                              stride = (stride_h, stride_w),
                              padding = 'VALID')
    else:
        shortcut = inputs

    merged = tf.add(shortcut, x)
    return merged
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号