gan.py source code


Project: tf_serving_example    Author: Vetal1977
    def __init__(self, input_real, z_size, learning_rate, num_classes=10,
                 alpha=0.2, beta1=0.5, drop_rate=.5):
        """
        Initializes the GAN model.

        :param input_real: Real data for the discriminator.
        :param z_size: The number of entries in the noise vector.
        :param learning_rate: The learning rate to use for the Adam optimizer.
        :param num_classes: The number of classes to recognize.
        :param alpha: The slope of the left half of the leaky ReLU activation.
        :param beta1: The beta1 parameter for Adam.
        :param drop_rate: The probability of dropping a hidden unit (used in the discriminator).
        """

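        # Keep the learning rate in a non-trainable variable so it can be adjusted later (see shrink_lr)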
        self.learning_rate = tf.Variable(learning_rate, trainable=False)
        self.input_real = input_real
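        # Placeholders for the noise vector, class labels, label mask and dropout rate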
        self.input_z = tf.placeholder(tf.float32, (None, z_size), name='input_z')
        self.y = tf.placeholder(tf.int32, (None), name='y')
        self.label_mask = tf.placeholder(tf.int32, (None), name='label_mask')
        self.drop_rate = tf.placeholder_with_default(drop_rate, (), "drop_rate")

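        # Build the graph: losses, accuracy counters, generator samples and discriminator outputs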
        loss_results = self.model_loss(self.input_real, self.input_z,
                                       self.input_real.shape[3], self.y, num_classes,
                                       label_mask=self.label_mask,
                                       drop_rate=self.drop_rate,
                                       alpha=alpha)

        self.d_loss, self.g_loss, self.correct, \
            self.masked_correct, self.samples, self.pred_class, \
                self.discriminator_class_logits, self.discriminator_out = \
                    loss_results

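        # Optimizers for both networks, plus an op that shrinks the learning rate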
        self.d_opt, self.g_opt, self.shrink_lr = self.model_opt(self.d_loss,
                                                                self.g_loss,
                                                                self.learning_rate, beta1)
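
For context, here is a minimal usage sketch. It assumes this __init__ belongs to a class named GAN in gan.py (as in the tf_serving_example project) and that model_loss and model_opt are defined on the same class; the 32x32x3 input shape, batch size and random data below are illustrative placeholders, not values taken from the original project.

import numpy as np
import tensorflow as tf

from gan import GAN  # assumption: this __init__ is part of a class named GAN

z_size = 100
learning_rate = 0.0003

# Real-image input for the discriminator; the 32x32 RGB shape is illustrative.
input_real = tf.placeholder(tf.float32, (None, 32, 32, 3), name='input_real')
net = GAN(input_real, z_size, learning_rate)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # One illustrative training step with random data in place of a real batch.
    batch_images = np.random.uniform(-1, 1, size=(16, 32, 32, 3))
    batch_z = np.random.uniform(-1, 1, size=(16, z_size))
    batch_y = np.random.randint(0, 10, size=16).astype(np.int32)
    mask = np.ones(16, dtype=np.int32)

    feed = {net.input_real: batch_images,
            net.input_z: batch_z,
            net.y: batch_y,
            net.label_mask: mask}

    _, d_loss = sess.run([net.d_opt, net.d_loss], feed_dict=feed)
    _, g_loss = sess.run([net.g_opt, net.g_loss], feed_dict=feed)

Note that drop_rate is created with tf.placeholder_with_default, so it does not need to appear in the feed dictionary unless you want to override the default.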