network.py source code

python

Project: tf_base  Author: ozansener
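import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, Dimension.value)

def fc(self, input, num_out, name, relu=True):
    """Fully connected layer: flattens 4-D (spatial) input, then computes
    input * weights + biases, optionally followed by a ReLU."""
    with tf.variable_scope(name) as scope:
        input_shape = input.get_shape()
        if input_shape.ndims == 4:
            # The input is spatial (batch first). Flatten all non-batch dimensions
            # into a single feature dimension of size dim.
            dim = 1
            for d in input_shape[1:].as_list():
                dim *= d
            feed_in = tf.reshape(input, [-1, dim])
        else:
            # The input is already a [batch, features] matrix.
            feed_in, dim = (input, input_shape[-1].value)
        weights = self.make_var('weights', shape=[dim, num_out],
                                init_func=tf.truncated_normal_initializer(stddev=0.1))
        biases = self.make_var('biases', [num_out],
                               init_func=tf.constant_initializer(0.1))
        # relu_layer computes relu(feed_in * weights + biases);
        # xw_plus_b computes feed_in * weights + biases with no activation.
        op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b
        fc = op(feed_in, weights, biases, name=scope.name)
        return fc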
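To make the arithmetic of `fc` concrete without requiring TensorFlow, here is a minimal pure-Python sketch of the same forward pass: flatten each spatial sample into one row of `dim` features, then compute `x * W + b` and optionally apply ReLU. The helper names `flatten_spatial` and `fc_forward` are hypothetical, and the sketch assumes NHWC nested lists (flattened in row-major order, like `tf.reshape`); it does not create variables or use `make_var`.

```python
from typing import List

Matrix = List[List[float]]
Batch4D = List[List[List[List[float]]]]  # [batch][height][width][channels]


def flatten_spatial(x: Batch4D) -> Matrix:
    """Flatten each [H][W][C] sample into a row of length H*W*C (row-major order)."""
    return [[v for h in sample for w in h for v in w] for sample in x]


def fc_forward(x: Matrix, weights: Matrix, biases: List[float], relu: bool = True) -> Matrix:
    """Compute x * weights + biases, optionally followed by ReLU.

    Shapes: x is [batch][dim], weights is [dim][num_out], biases is [num_out].
    """
    dim, num_out = len(weights), len(biases)
    out = []
    for row in x:
        if len(row) != dim:
            raise ValueError(f"expected {dim} features, got {len(row)}")
        # Affine step: one dot product per output unit, plus its bias.
        y = [sum(row[i] * weights[i][j] for i in range(dim)) + biases[j]
             for j in range(num_out)]
        if relu:
            # Same role as tf.nn.relu_layer vs. tf.nn.xw_plus_b.
            y = [max(0.0, v) for v in y]
        out.append(y)
    return out


# Usage: one sample with H=1, W=2, C=1, so dim = 1 * 2 * 1 = 2.
x = flatten_spatial([[[[1.0], [2.0]]]])
w = [[1.0, -1.0], [1.0, -1.0]]
print(fc_forward(x, w, [0.5, 0.5]))               # [[3.5, 0.0]]
print(fc_forward(x, w, [0.5, 0.5], relu=False))   # [[3.5, -2.5]]
```

The explicit loops trade speed for readability; in the TensorFlow version the same work is a single matrix multiply over the whole batch.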