resnet.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:various_residual_networks 作者: yuhui-lin 项目源码 文件源码
def resnn(self, image_batch):
    """Build the residual network graph and return its logits.

    Pipeline: a stem convolution with BN/ReLU, an optional max pool, every
    residual unit of every group in ``self.groups``, an optional extra
    BN/ReLU, global average pooling, flattening, optional dropout, and a
    final fully connected layer with ``FLAGS.num_cats`` outputs.

    Args:
        image_batch: Input batch as returned by inputs_train() or
            inputs_eval(); presumably a 4-D NHWC image tensor (TODO confirm).

    Returns:
        Logits: the fully connected layer's output with no activation applied.
    """
    # Stem: self.conv2d with args 5, 1 (presumably kernel size and stride --
    # confirm against conv2d), producing groups[0].num_ker channels, then
    # batch norm + ReLU.
    with tf.variable_scope('conv_layer1'):
        x = self.conv2d(image_batch, self.groups[0].num_ker, 5, 1)
        x = self.BN_ReLU(x)

    # Optional 3x3 max pool; stride 1 with SAME padding leaves the spatial
    # size unchanged.
    if FLAGS.max_pool:
        x = tf.nn.max_pool(x,
                           ksize=[1, 3, 3, 1],
                           strides=[1, 1, 1, 1],
                           padding='SAME')

    # Residual units, visited group by group and unit by unit.
    unit_ids = ((g_idx, u_idx)
                for g_idx, grp in enumerate(self.groups)
                for u_idx in range(grp.num_units))
    for g_idx, u_idx in unit_ids:
        x = self.residual_unit(x, g_idx, u_idx)

    # Optional extra BN + ReLU ahead of pooling.
    if FLAGS.special_first:
        with tf.variable_scope('special_BN_ReLU'):
            x = self.BN_ReLU(x)

    # Global average pooling: the window spans the full static height and
    # width, and VALID padding collapses each feature map to 1x1
    # (output: batch x 1 x 1 x channels).
    feat_dims = x.get_shape().as_list()
    x = tf.nn.avg_pool(x,
                       ksize=[1, feat_dims[1], feat_dims[2], 1],
                       strides=[1, 1, 1, 1],
                       padding='VALID')

    # Flatten the pooled 4-D tensor into one feature vector per example.
    _, pooled_h, pooled_w, channels = x.get_shape().as_list()
    x = tf.reshape(x, [-1, pooled_h * pooled_w * channels])

    # Optional dropout before the classifier.
    # NOTE(review): no train/eval switch is visible here, so whenever
    # FLAGS.dropout is set, dropout also applies to graphs built from
    # inputs_eval(); confirm keep_prob is 1.0 at evaluation time. The second
    # positional argument is keep_prob in TF 1.x (TF 2.x treats it as
    # `rate`), so revisit this call when upgrading.
    if FLAGS.dropout:
        with tf.name_scope("dropout"):
            x = tf.nn.dropout(x, FLAGS.dropout_keep_prob)

    # Final 2-D fully connected neural network layer (linear classifier):
    # no activation or normalizer, variance-scaling weight init, L2 weight
    # regularization scaled by FLAGS.weight_decay, zero bias init.
    with tf.variable_scope('FC-layer'):
        logits = fully_connected(
            x,
            num_outputs=FLAGS.num_cats,
            activation_fn=None,
            normalizer_fn=None,
            weights_initializer=variance_scaling_initializer(),
            weights_regularizer=l2_regularizer(FLAGS.weight_decay),
            biases_initializer=tf.zeros_initializer,
        )

    return logits
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号