vae_skipconn.py source code (Python)


Project: divcolor    Author: aditya12agd5
def __encoder(self, scope, input_tensor, bn_is_training, keep_prob, in_nch=2, reuse=False):
        """Encoder: four stride-2 conv+BN blocks, dropout, flatten, and a
        fully connected layer producing 2*hidden_size output units."""

        lf = self.layer_factory

        input_tensor2d = tf.reshape(input_tensor, [self.flags.batch_size,
                self.flags.img_height, self.flags.img_width, in_nch])

        # Channel count of the reshaped input; `tensor_shape` must be imported
        # via `from tensorflow.python.framework import tensor_shape`.
        nch = tensor_shape.as_dimension(input_tensor2d.get_shape()[3]).value

        if not reuse:
            # First call: create the encoder variables.
            W_conv1 = lf.weight_variable(name='W_conv1', shape=[5, 5, nch, 128])
            W_conv2 = lf.weight_variable(name='W_conv2', shape=[5, 5, 128, 256])
            W_conv3 = lf.weight_variable(name='W_conv3', shape=[5, 5, 256, 512])
            W_conv4 = lf.weight_variable(name='W_conv4', shape=[4, 4, 512, 1024])
            W_fc1 = lf.weight_variable(name='W_fc1', shape=[4*4*1024, self.flags.hidden_size * 2])

            b_conv1 = lf.bias_variable(name='b_conv1', shape=[128])
            b_conv2 = lf.bias_variable(name='b_conv2', shape=[256])
            b_conv3 = lf.bias_variable(name='b_conv3', shape=[512])
            b_conv4 = lf.bias_variable(name='b_conv4', shape=[1024])
            b_fc1 = lf.bias_variable(name='b_fc1', shape=[self.flags.hidden_size * 2])
        else:
            # Subsequent calls: fetch the existing variables by name.
            W_conv1 = lf.weight_variable(name='W_conv1')
            W_conv2 = lf.weight_variable(name='W_conv2')
            W_conv3 = lf.weight_variable(name='W_conv3')
            W_conv4 = lf.weight_variable(name='W_conv4')
            W_fc1 = lf.weight_variable(name='W_fc1')

            b_conv1 = lf.bias_variable(name='b_conv1')
            b_conv2 = lf.bias_variable(name='b_conv2')
            b_conv3 = lf.bias_variable(name='b_conv3')
            b_conv4 = lf.bias_variable(name='b_conv4')
            b_fc1 = lf.bias_variable(name='b_fc1')

        # Four stride-2 conv blocks take the input from 64x64 down to 4x4
        # (implied by the 4*4*1024 flatten below) while widening 128 -> 1024.
        conv1 = tf.nn.relu(lf.conv2d(input_tensor2d, W_conv1, stride=2) + b_conv1)
        conv1_norm = lf.batch_norm_aiuiuc_wrapper(conv1, bn_is_training,
            'BN1', reuse_vars=reuse)

        conv2 = tf.nn.relu(lf.conv2d(conv1_norm, W_conv2, stride=2) + b_conv2)
        conv2_norm = lf.batch_norm_aiuiuc_wrapper(conv2, bn_is_training,
            'BN2', reuse_vars=reuse)

        conv3 = tf.nn.relu(lf.conv2d(conv2_norm, W_conv3, stride=2) + b_conv3)
        conv3_norm = lf.batch_norm_aiuiuc_wrapper(conv3, bn_is_training,
            'BN3', reuse_vars=reuse)

        conv4 = tf.nn.relu(lf.conv2d(conv3_norm, W_conv4, stride=2) + b_conv4)
        conv4_norm = lf.batch_norm_aiuiuc_wrapper(conv4, bn_is_training,
            'BN4', reuse_vars=reuse)

        # Dropout on the final feature map, flatten, and project to the
        # 2*hidden_size latent parameters.
        dropout1 = tf.nn.dropout(conv4_norm, keep_prob)
        flatten1 = tf.reshape(dropout1, [-1, 4*4*1024])

        fc1 = tf.matmul(flatten1, W_fc1) + b_fc1

        return fc1
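
The snippet leans on a layer_factory helper whose implementation is not shown here. Below is a minimal sketch of what those helpers plausibly look like, assuming TF1-style variables cached by name so that a second call with only a name (the reuse branch above) returns the existing variable; the real divcolor layer_factory (and its batch_norm_aiuiuc_wrapper, omitted here) may differ.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

class LayerFactory(object):
    """Sketch only: variables are created on the first call (shape given)
    and returned from a cache on later calls (name only)."""

    def __init__(self):
        self._vars = {}

    def weight_variable(self, name, shape=None):
        if name not in self._vars:
            self._vars[name] = tf.Variable(
                tf.truncated_normal(shape, stddev=0.01), name=name)
        return self._vars[name]

    def bias_variable(self, name, shape=None):
        if name not in self._vars:
            self._vars[name] = tf.Variable(tf.zeros(shape), name=name)
        return self._vars[name]

    def conv2d(self, x, W, stride=1):
        # 'SAME' padding with stride 2 halves each spatial dimension,
        # matching the 64x64 -> 4x4 progression in the encoder.
        return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding='SAME')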
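
The encoder returns one tensor of width 2*hidden_size rather than separate mean and variance heads. In a typical VAE, the caller splits this output into a mean and a log-variance and samples the latent with the reparameterization trick. A hypothetical sketch of that downstream step (variable names are illustrative, not taken from this file):

# Hypothetical caller-side code; `fc1` is the tensor returned by __encoder
# and `hidden_size` mirrors self.flags.hidden_size.
mean = fc1[:, :hidden_size]
log_sigma_sq = fc1[:, hidden_size:]
eps = tf.random_normal(tf.shape(mean))          # eps ~ N(0, I)
z = mean + tf.exp(0.5 * log_sigma_sq) * eps     # z = mean + sigma * eps

Sampling through eps keeps the path from z back to the encoder weights differentiable, which is the point of the reparameterization trick.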