base.py source code

python

Project: TensorBase  Author: dancsalo
def flatten(self, keep_prob=1):
        """
        Flattens a 4D Tensor (from a Conv layer) into a 2D Tensor (for an FC layer).
        :param keep_prob: float. Probability of keeping each unit; set to 1 for no dropout
        """
        self.count['flat'] += 1
        scope = 'flat_' + str(self.count['flat'])
        with tf.variable_scope(scope):
            # Collapse dimensions 1-3 into a single axis; -1 preserves the batch size
            input_nodes = tf.Dimension(
                self.input.get_shape()[1] * self.input.get_shape()[2] * self.input.get_shape()[3])
            output_shape = tf.stack([-1, input_nodes])
            self.input = tf.reshape(self.input, output_shape)

            # Apply dropout unless keep_prob is exactly 1
            if keep_prob != 1:
                self.input = tf.nn.dropout(self.input, keep_prob=keep_prob)
        print(scope + ' output: ' + str(self.input.get_shape()))
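
The shape arithmetic behind the reshape above can be sketched without TensorFlow. This is a minimal illustration in plain Python, not part of TensorBase: the helper names `flattened_shape` and `flatten_example` are hypothetical, and the input is assumed to be laid out as (batch, height, width, channels).

```python
from functools import reduce
from operator import mul


def flattened_shape(shape):
    """Return the 2D shape obtained by flattening a 4D shape.

    The batch dimension is passed through unchanged (it may be None
    when unknown); the other three dimensions are multiplied together.
    """
    if len(shape) != 4:
        raise ValueError("expected a 4D shape, got %d dimensions" % len(shape))
    batch, *rest = shape
    return (batch, reduce(mul, rest, 1))


def flatten_example(example):
    """Flatten one nested-list example of shape (h, w, c) in row-major order."""
    return [value for row in example for pixel in row for value in pixel]
```

For instance, `flattened_shape((None, 7, 7, 64))` gives `(None, 3136)`, since 7 * 7 * 64 = 3136. The row-major order used by `flatten_example` is assumed to mirror how a reshape lays out elements.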