model.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:DeepWorks 作者: daigo0927 项目源码 文件源码
def __call__(self, inputs, reuse=True):
        """Build the network on ``inputs`` and return unnormalised logits.

        The graph is: a 7x7 stride-2 convolution with 64 filters, batch
        normalisation, ReLU and a 3x3 stride-2 max-pool; then one residual
        stage per entry of ``self.repetitions`` (the filter count starts at
        64 and doubles after every stage); then average pooling over the
        whole remaining spatial extent, flattening, and a linear
        fully-connected layer with ``self.num_output`` units.

        Args:
            inputs: Tensor fed to the first convolution (presumably a 4-D
                NHWC image batch -- TODO confirm against callers).
            reuse: When truthy, the variable scope named ``self.name`` is
                switched to reuse mode before any layer is created.

        Returns:
            The output tensor of the final fully-connected layer, created
            with ``activation_fn=None``.
        """
        with tf.variable_scope(self.name) as scope:
            if reuse:
                scope.reuse_variables()

            # Stem: conv -> batch norm -> ReLU -> max-pool.
            # NOTE(review): if ``tcl`` is tf.contrib.layers, conv2d applies
            # its default ReLU activation, so a ReLU runs both before and
            # after batch_norm here. Confirm whether ``activation_fn=None``
            # was intended; left unchanged to keep trained models'
            # numerics intact.
            stem = tcl.conv2d(inputs,
                              num_outputs=64,
                              kernel_size=(7, 7),
                              stride=(2, 2),
                              padding='SAME')
            # NOTE(review): batch_norm is called with library defaults, so
            # there is no train/inference switch at this call site.
            stem = tcl.batch_norm(stem)
            stem = tf.nn.relu(stem)
            stem = tcl.max_pool2d(stem,
                                  kernel_size=(3, 3),
                                  stride=(2, 2),
                                  padding='SAME')

            # Residual stages: only the very first stage is flagged as the
            # first layer; the filter count doubles from one stage to the
            # next.
            x = stem
            filters = 64
            for stage_index, repetition in enumerate(self.repetitions):
                stage = _residual_block(self.block_fn,
                                        filters=filters,
                                        repetition=repetition,
                                        is_first_layer=(stage_index == 0))
                x = stage(x)
                filters *= 2

            # Global average pooling: the kernel spans the full feature map.
            # The static shape is used, so the spatial dimensions must be
            # known at graph-construction time (as_list() yields None for
            # unknown dimensions).
            _, height, width, _ = x.shape.as_list()
            pooled = tcl.avg_pool2d(x,
                                    kernel_size=(height, width),
                                    stride=(1, 1))
            features = tcl.flatten(pooled)

            # Linear classifier head: no activation, callers get raw logits.
            logits = tcl.fully_connected(features,
                                         num_outputs=self.num_output,
                                         activation_fn=None)
            return logits
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号