base.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:TensorBase 作者: dancsalo 项目源码 文件源码
def fc(self, output_nodes, keep_prob=1, activation_fn=tf.nn.relu, b_value=0.0, s_value=1.0, bn=True,
           trainable=True):
        """
        Fully Connected Layer

        Applies a dense layer to the tensor held in ``self.input`` and stores the
        result back into ``self.input``. Steps run in this order: flatten (only
        when the input is rank 4), matmul with a new weight variable, batch norm,
        bias add, scale multiply, activation, dropout. The resulting static shape
        is printed; nothing is returned.

        :param output_nodes: int. number of columns of the weight matrix
        :param keep_prob: float. forwarded to tf.nn.dropout; set to 1 for no dropout
        :param activation_fn: tf.nn function, or None to skip the activation
        :param b_value: float or None. value handed to self.const_variable for the
            bias; None skips the bias add
        :param s_value: float or None. value handed to self.const_variable for the
            scale; None skips the scale multiply
        :param bn: bool. any truthy value enables batch normalization
        :param trainable: bool. forwarded to the weight, bias and scale variable helpers
        """
        self.count['fc'] += 1
        scope = 'fc_' + str(self.count['fc'])
        with tf.variable_scope(scope):

            # Flatten if necessary: collapse the three trailing axes of a rank-4
            # input into one so the tensor can be fed to matmul.
            if len(self.input.get_shape()) == 4:
                input_nodes = tf.Dimension(
                    self.input.get_shape()[1] * self.input.get_shape()[2] * self.input.get_shape()[3])
                output_shape = tf.stack([-1, input_nodes])
                self.input = tf.reshape(self.input, output_shape)

            # Matrix Multiplication Function
            input_nodes = self.input.get_shape()[1]
            output_shape = [input_nodes, output_nodes]
            w = self.weight_variable(name='weights', shape=output_shape, trainable=trainable)
            self.input = tf.matmul(self.input, w)

            # Truthiness test rather than `is True`, so truthy non-bool flags
            # (e.g. numpy.bool_(True) or 1) are not silently ignored.
            if bn:  # batch normalization
                self.input = self.batch_norm(self.input, 'fc')
            if b_value is not None:  # bias value
                b = self.const_variable(name='bias', shape=[output_nodes], value=b_value, trainable=trainable)
                self.input = tf.add(self.input, b)
            if s_value is not None:  # scale value
                s = self.const_variable(name='scale', shape=[output_nodes], value=s_value, trainable=trainable)
                self.input = tf.multiply(self.input, s)
            if activation_fn is not None:  # activation function
                self.input = activation_fn(self.input)
            if keep_prob != 1:  # dropout function
                self.input = tf.nn.dropout(self.input, keep_prob=keep_prob)
        print(scope + ' output: ' + str(self.input.get_shape()))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号