resnet.py source code

python

Project: blitznet  Author: dvornikita
def create_segmentation_head(self, num_classes):
        """Build the segmentation head; the output map has stride 8, or 4 if the --x4 flag is active."""
        with tf.variable_scope(DEFAULT_SSD_SCOPE) as sc:
            # Defaults applied to every slim.conv2d call inside this scope.
            with slim.arg_scope([slim.conv2d],
                                kernel_size=args.seg_filter_size,
                                weights_regularizer=slim.l2_regularizer(self.weight_decay),
                                biases_initializer=tf.zeros_initializer()):

                seg_materials = []
                # Target spatial size: the first entry of fm_sizes.
                seg_size = self.config['fm_sizes'][0]
                for layer_name in self.layers:
                    target_layer = self.outputs[layer_name]
                    # Project each feature map to a common channel count.
                    seg = slim.conv2d(target_layer, args.n_base_channels)
                    # Bring every map to the same spatial resolution.
                    seg = tf.image.resize_nearest_neighbor(seg, [seg_size, seg_size])
                    seg_materials.append(seg)
                # Fuse all scales along the channel axis.
                seg_materials = tf.concat(seg_materials, -1)
                # Per-pixel class logits; no activation, so the loss can apply softmax.
                seg_logits = slim.conv2d(seg_materials, num_classes,
                                         kernel_size=3, activation_fn=None)
                self.outputs['segmentation'] = seg_logits
                return self.outputs['segmentation']
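
The fusion step in the method above (resize every feature map to one spatial size, then concatenate along the channel axis) can be illustrated without TensorFlow. The following is only a minimal pure-Python sketch on nested lists, not the project's code: `resize_nearest` and `fuse_feature_maps` are hypothetical helper names, the per-layer convolutions are omitted, and the index mapping `floor(dst * in / out)` is an assumption about how nearest-neighbor resizing behaves with default settings.

```python
def resize_nearest(fmap, out_h, out_w):
    """Nearest-neighbor resize of an H x W x C nested list to out_h x out_w x C."""
    in_h, in_w = len(fmap), len(fmap[0])
    out = []
    for y in range(out_h):
        # Assumed mapping: source index = floor(dst * in / out), clamped to the last row.
        src_y = min(y * in_h // out_h, in_h - 1)
        row = []
        for x in range(out_w):
            src_x = min(x * in_w // out_w, in_w - 1)
            row.append(list(fmap[src_y][src_x]))  # copy the channel vector
        out.append(row)
    return out


def fuse_feature_maps(fmaps, size):
    """Resize every map to size x size and concatenate along the channel axis."""
    resized = [resize_nearest(f, size, size) for f in fmaps]
    fused = []
    for y in range(size):
        # Joining the per-map channel lists plays the role of tf.concat(..., -1).
        fused.append([sum((r[y][x] for r in resized), []) for x in range(size)])
    return fused


# Toy pyramid: a 4x4 map with 2 channels and a 2x2 map with 1 channel.
fine = [[[y, x] for x in range(4)] for y in range(4)]
coarse = [[[10 * y + x] for x in range(2)] for y in range(2)]
fused = fuse_feature_maps([fine, coarse], 4)
```

Each position of `fused` holds the 2 channels of the fine map followed by the 1 channel of the upsampled coarse map. In the original method a final 3x3 convolution then maps these concatenated channels to `num_classes` logits per pixel.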