model_tmmd.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:opt-mmd 作者: dougalsutherland 项目源码 文件源码
def generator(self, z, y=None, is_train=True, reuse=False):
    """Build the generator graph that maps `z` to a tanh-activated output.

    Two architectures are supported, selected by `self.output_size`:

    * not a multiple of 16 -- a shallow stack: linear projection followed
      by two `deconv2d` layers;
    * a multiple of 16 -- a deeper stack: linear projection followed by
      four `deconv2d` layers.

    In both cases the last `deconv2d` is asked for an output shape of
    `[self.batch_size, output_size, output_size, self.c_dim]`.

    Args:
        z: Input passed to the first `linear` projection.
        y: Unused by this method.
        is_train: Forwarded as `train=` to every `self.g_bn*` call.
        reuse: When truthy, `reuse_variables()` is called on the current
            variable scope before any layer is created.

    Returns:
        `tf.nn.tanh` applied to the output of the final `deconv2d` layer.

    Side effects:
        Stores intermediate tensors and the weights/biases returned by
        `linear`/`deconv2d` on `self` (`z_`, `h0`, `h1`, `h0_w`, `h0_b`,
        `h1_w`, `h1_b`, `h2_w`, `h2_b`, and for the deep stack also
        `h3_w`, `h3_b`, `h4_w`, `h4_b`).
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()

    size = self.output_size
    batch = self.batch_size

    if np.mod(size, 16) != 0:
        # Shallow variant: only two upsampling steps (size/4 -> size/2 -> size).
        half, quarter = int(size / 2), int(size / 4)

        self.z_, self.h0_w, self.h0_b = linear(
            z, self.gf_dim * 2 * quarter * quarter, 'g_h0_lin', with_w=True)

        self.h0 = tf.reshape(
            self.z_, [-1, quarter, quarter, self.gf_dim * 2])
        act0 = tf.nn.relu(self.g_bn0(self.h0, train=is_train))

        self.h1, self.h1_w, self.h1_b = deconv2d(
            act0, [batch, half, half, self.gf_dim * 1],
            name='g_h1', with_w=True)
        act1 = tf.nn.relu(self.g_bn1(self.h1, train=is_train))

        out, self.h2_w, self.h2_b = deconv2d(
            act1, [batch, size, size, self.c_dim],
            name='g_h2', with_w=True)

        return tf.nn.tanh(out)

    # Deep variant: four upsampling steps (size/16 -> ... -> size).
    half, quarter, eighth, sixteenth = (
        int(size / divisor) for divisor in (2, 4, 8, 16))

    # Project `z` and reshape it into the smallest feature map.
    self.z_, self.h0_w, self.h0_b = linear(
        z, self.gf_dim * 8 * sixteenth * sixteenth, 'g_h0_lin', with_w=True)

    self.h0 = tf.reshape(
        self.z_, [-1, sixteenth, sixteenth, self.gf_dim * 8])
    act0 = tf.nn.relu(self.g_bn0(self.h0, train=is_train))

    self.h1, self.h1_w, self.h1_b = deconv2d(
        act0, [batch, eighth, eighth, self.gf_dim * 4],
        name='g_h1', with_w=True)
    act1 = tf.nn.relu(self.g_bn1(self.h1, train=is_train))

    pre2, self.h2_w, self.h2_b = deconv2d(
        act1, [batch, quarter, quarter, self.gf_dim * 2],
        name='g_h2', with_w=True)
    act2 = tf.nn.relu(self.g_bn2(pre2, train=is_train))

    pre3, self.h3_w, self.h3_b = deconv2d(
        act2, [batch, half, half, self.gf_dim * 1],
        name='g_h3', with_w=True)
    act3 = tf.nn.relu(self.g_bn3(pre3, train=is_train))

    out, self.h4_w, self.h4_b = deconv2d(
        act3, [batch, size, size, self.c_dim],
        name='g_h4', with_w=True)
    return tf.nn.tanh(out)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号