train_tf.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:DeepWorks 作者: daigo0927 项目源码 文件源码
def __call__(self, inputs, reuse = True):
        with tf.variable_scope(self.name) as vs:
            # tf.get_variable_scope()
            if reuse:
                vs.reuse_variables()

            x = tcl.conv2d(inputs,
                           num_outputs = 64,
                           kernel_size = (4, 4),
                           stride = (1, 1),
                           padding = 'SAME')
            x = tcl.batch_norm(x)
            x = tf.nn.relu(x)
            x = tcl.max_pool2d(x, (2, 2), (2, 2), 'SAME')
            x = tcl.conv2d(x,
                           num_outputs = 128,
                           kernel_size = (4, 4),
                           stride = (1, 1),
                           padding = 'SAME')
            x = tcl.batch_norm(x)
            x = tf.nn.relu(x)
            x = tcl.max_pool2d(x, (2, 2), (2, 2), 'SAME')
            x = tcl.flatten(x)
            logits = tcl.fully_connected(x, num_outputs = self.num_output)

            return logits
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号