network_base.py source code

python

Project: tf-openpose  Author: ildoonet
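import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in TF 2.x
import tensorflow.contrib.slim as slim

import common  # tf-openpose's own module (provides batchnorm_fused, regularizer_dsconv)

DEFAULT_PADDING = 'SAME'  # assumed value of the module-level padding constant


# Method of the network base class; self.trainable controls whether the variables are trained.
def separable_conv(self, input, k_h, k_w, c_o, stride, name, relu=True):
    with slim.arg_scope([slim.batch_norm], fused=common.batchnorm_fused):
        # Depthwise stage: one k_h x k_w filter per input channel. num_outputs=None makes
        # slim skip its built-in pointwise step; no bias, normalization or activation here.
        output = slim.separable_convolution2d(input,
                                              num_outputs=None,
                                              stride=stride,
                                              trainable=self.trainable,
                                              depth_multiplier=1.0,
                                              kernel_size=[k_h, k_w],
                                              activation_fn=None,
                                              weights_initializer=tf.contrib.layers.xavier_initializer(),
                                              # weights_initializer=tf.truncated_normal_initializer(stddev=0.09),
                                              weights_regularizer=tf.contrib.layers.l2_regularizer(0.00004),
                                              biases_initializer=None,
                                              padding=DEFAULT_PADDING,
                                              scope=name + '_depthwise')

        # Pointwise stage: 1x1 convolution to c_o channels, then batch norm and optional ReLU.
        output = slim.convolution2d(output,
                                    c_o,
                                    stride=1,
                                    kernel_size=[1, 1],
                                    activation_fn=tf.nn.relu if relu else None,
                                    weights_initializer=tf.contrib.layers.xavier_initializer(),
                                    # weights_initializer=tf.truncated_normal_initializer(stddev=0.09),
                                    # Ignored in practice: slim adds no bias when normalizer_fn is set.
                                    biases_initializer=tf.zeros_initializer(),
                                    normalizer_fn=slim.batch_norm,
                                    trainable=self.trainable,
                                    weights_regularizer=tf.contrib.layers.l2_regularizer(common.regularizer_dsconv),
                                    # weights_regularizer=None,
                                    scope=name + '_pointwise')

    return output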
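The method stacks two stages: a depthwise convolution (one filter per input channel, depth_multiplier=1, no bias or activation) and a 1x1 pointwise convolution followed by batch norm and an optional ReLU. The pure-Python sketch that follows reproduces the arithmetic of those two stages on nested lists, assuming TensorFlow's 'SAME' padding rule and an H x W x C layout without a batch axis. Batch normalization, regularizers and initializers are left out, and the helper names (depthwise_conv, pointwise_conv, param_counts) are hypothetical, not part of tf-openpose.

```python
import math
from typing import List, Tuple

# Feature maps are H x W x C nested lists (NHWC without the batch axis).
Tensor3 = List[List[List[float]]]


def same_pad_before(in_size: int, k: int, stride: int) -> int:
    """Leading zero-padding used by TensorFlow's 'SAME' mode along one axis."""
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + k - in_size, 0)
    return total // 2


def depthwise_conv(x: Tensor3, kernels: Tensor3, stride: int = 1) -> Tensor3:
    """Depthwise stage: each channel gets its own k_h x k_w filter.

    kernels[i][j][c] is the weight at offset (i, j) for channel c
    (depth_multiplier = 1, no bias, no activation).
    """
    h, w, c = len(x), len(x[0]), len(x[0][0])
    k_h, k_w = len(kernels), len(kernels[0])
    out_h, out_w = math.ceil(h / stride), math.ceil(w / stride)
    pad_t = same_pad_before(h, k_h, stride)
    pad_l = same_pad_before(w, k_w, stride)
    out = [[[0.0] * c for _ in range(out_w)] for _ in range(out_h)]
    for oy in range(out_h):
        for ox in range(out_w):
            for ch in range(c):
                acc = 0.0
                for i in range(k_h):
                    for j in range(k_w):
                        row = oy * stride + i - pad_t
                        col = ox * stride + j - pad_l
                        # Positions outside the input act as zero padding.
                        if 0 <= row < h and 0 <= col < w:
                            acc += x[row][col][ch] * kernels[i][j][ch]
                out[oy][ox][ch] = acc
    return out


def pointwise_conv(x: Tensor3, weights: List[List[float]], relu: bool = True) -> Tensor3:
    """Pointwise stage: a 1x1 convolution mapping c_in channels to c_o channels.

    weights[ci][co] mixes input channel ci into output channel co.
    Batch normalization is omitted in this sketch.
    """
    c_o = len(weights[0])
    out = []
    for line in x:
        out_line = []
        for pixel in line:
            vals = [sum(pixel[ci] * weights[ci][co] for ci in range(len(pixel)))
                    for co in range(c_o)]
            out_line.append([max(v, 0.0) for v in vals] if relu else vals)
        out.append(out_line)
    return out


def param_counts(k_h: int, k_w: int, c_in: int, c_o: int) -> Tuple[int, int]:
    """Weight counts: standard conv vs. depthwise + pointwise (biases ignored)."""
    standard = k_h * k_w * c_in * c_o
    separable = k_h * k_w * c_in + c_in * c_o
    return standard, separable


if __name__ == "__main__":
    image = [[[1.0, 2.0] for _ in range(4)] for _ in range(4)]  # 4x4, 2 channels
    box = [[[1.0, 1.0] for _ in range(3)] for _ in range(3)]    # 3x3 box filter per channel
    mix = [[1.0, 0.5, -1.0], [1.0, 0.5, -1.0]]                  # 2 -> 3 channels
    y = pointwise_conv(depthwise_conv(image, box, stride=2), mix)
    print(len(y), len(y[0]), len(y[0][0]))  # prints "2 2 3"
    print(param_counts(3, 3, 128, 128))
```

For a 3x3 kernel mapping 128 channels to 128, param_counts(3, 3, 128, 128) gives (147456, 17536), roughly 8.4x fewer weights than a standard convolution, which is why depthwise separable blocks appear in lightweight backbones such as MobileNet.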