convnet.py source code

python

Project: deep_unsupervised_posets  Author: asanakoy
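import tensorflow as tf
import tensorflow.contrib.layers as tflayers  # TF 1.x only; alias assumed from its use below


# Method of class Convnet (excerpt). It relies on self.get_conv_weights,
# self.is_phase_train, self.batch_norm_decay and the static helper Convnet.conv,
# which are defined elsewhere in convnet.py and are not shown here.
def conv_relu(self, input_tensor, kernel_size, kernels_num, stride, batch_norm=True,
              group=1, name=None):
    """Convolution (optionally grouped), optional batch norm, then ReLU."""
    with tf.variable_scope(name) as scope:
        in_channels = int(input_tensor.get_shape()[3])  # NHWC: channel axis is 3
        # The input channels must split evenly across the groups.
        assert in_channels % group == 0
        # Integer division: under Python 3, `/` would produce a float channel count.
        num_input_channels = in_channels // group
        w, b = self.get_conv_weights(kernel_size, num_input_channels, kernels_num)
        conv = Convnet.conv(input_tensor, w, b, stride, padding="SAME", group=group)
        if batch_norm:
            # self.is_phase_train is presumably a scalar boolean tensor. The training
            # branch (built first) normalizes with batch statistics and creates the BN
            # variables; the inference branch reuses them (reuse=True) and normalizes
            # with the stored moving averages.
            conv = tf.cond(self.is_phase_train,
                           lambda: tflayers.batch_norm(conv,
                                                       decay=self.batch_norm_decay,
                                                       is_training=True,
                                                       trainable=True,
                                                       reuse=None,
                                                       scope=scope),
                           lambda: tflayers.batch_norm(conv,
                                                       decay=self.batch_norm_decay,
                                                       is_training=False,
                                                       trainable=True,
                                                       reuse=True,
                                                       scope=scope))
        conv = tf.nn.relu(conv, name=name)
    return conv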
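The group argument splits the input channels into equal groups, which is why the method asserts divisibility and passes in_channels // group to get_conv_weights. Convnet.conv is not part of this listing, so the semantics below are an assumption based on the common AlexNet-style split-and-concatenate grouped convolution: output channel o belongs to group o // (out_channels // group) and sees only that group's input channels. The sketch is a hypothetical pure-Python helper, grouped_conv1d, reduced to 1-D with stride 1 and VALID padding (the original is 2-D with SAME padding); like TensorFlow's conv ops it computes cross-correlation.

```python
from typing import List

Signal = List[List[float]]  # [length][channels]


def grouped_conv1d(x: Signal, kernel: List[List[List[float]]], group: int) -> Signal:
    """Hypothetical grouped 1-D convolution, stride 1, VALID padding.

    x:      [length][in_channels]
    kernel: [k][in_channels // group][out_channels]  (like TF's HWIO layout)
    Returns [length - k + 1][out_channels].
    """
    in_channels = len(x[0])
    k = len(kernel)
    out_channels = len(kernel[0][0])
    # Same divisibility requirement as the assert in conv_relu.
    assert in_channels % group == 0 and out_channels % group == 0
    cin_g = in_channels // group   # input channels per group (integer division)
    cout_g = out_channels // group  # output channels per group
    assert len(kernel[0]) == cin_g
    out = []
    for t in range(len(x) - k + 1):
        row = []
        for o in range(out_channels):
            g = o // cout_g  # group that this output channel belongs to
            acc = 0.0
            for dt in range(k):
                for c in range(cin_g):
                    # Only input channels of group g contribute to output channel o.
                    acc += x[t + dt][g * cin_g + c] * kernel[dt][c][o]
            row.append(acc)
        out.append(row)
    return out


# With 4 input channels and group=2, output 0 sums channels 0-1 and output 1 sums 2-3.
print(grouped_conv1d([[1.0, 2.0, 3.0, 4.0]], [[[1.0, 1.0], [1.0, 1.0]]], 2))
```

With group=1 the same input needs a kernel with 4 input channels per output, and every output mixes all channels; grouping cuts the weight count per output channel by a factor of group.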
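The tf.cond in conv_relu switches batch normalization between two modes that share one set of variables. The hypothetical single-channel class below mimics that behaviour under stated assumptions: it uses the tf.contrib.layers.batch_norm defaults (decay 0.999, epsilon 0.001, moving mean initialized to 0 and moving variance to 1, a learned offset beta and no scale, as with center=True, scale=False). In training mode it normalizes with the batch mean and variance and updates the moving averages with the decay factor; in inference mode it normalizes with the stored averages and leaves them unchanged. As a simplification it updates the averages immediately, whereas TensorFlow by default places those update ops in tf.GraphKeys.UPDATE_OPS and they must be run explicitly.

```python
import math
from typing import List


class BatchNorm1Channel:
    """Hypothetical single-channel batch norm mirroring the two tf.cond branches."""

    def __init__(self, decay: float = 0.999, epsilon: float = 1e-3) -> None:
        self.decay = decay
        self.epsilon = epsilon
        self.beta = 0.0         # learned offset; no gamma, as with scale=False
        self.moving_mean = 0.0  # shared state, like the reused BN variables
        self.moving_var = 1.0

    def __call__(self, batch: List[float], is_training: bool) -> List[float]:
        if is_training:
            # Batch statistics (population variance, as tf.nn.moments computes).
            mean = sum(batch) / len(batch)
            var = sum((v - mean) ** 2 for v in batch) / len(batch)
            # Exponential moving averages controlled by `decay`.
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            # Inference: use the stored averages and do not update them.
            mean, var = self.moving_mean, self.moving_var
        return [(v - mean) / math.sqrt(var + self.epsilon) + self.beta for v in batch]


bn = BatchNorm1Channel(decay=0.5, epsilon=0.0)
print(bn([1.0, 3.0], is_training=True))   # normalized with batch mean 2, variance 1
print(bn([1.0, 3.0], is_training=False))  # normalized with moving mean 1, variance 1
```

The same input gives different outputs in the two modes, which is why conv_relu needs the is_phase_train switch rather than a single batch_norm call.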