cascgan_model.py 文件源码

python
阅读 15 收藏 0 点赞 0 评论 0

项目:gan_tensorflow 作者: dantkz 项目源码 文件源码
def _initialize_params(self):
        """Create the generator and discriminator parameters.

        For the generator and for every entry of
        ``self.discriminators_params`` this creates one weight tensor per
        layer. The last layer of each network also gets a bias. Every other
        layer gets a batch-norm object instead.

        Returns:
            A 4-tuple ``(all_weights, batch_norms, gen_vars, disc_vars)``:
              * ``all_weights`` -- dict mapping variable name to the object
                returned by ``ops.variable`` (weights and biases of both
                networks).
              * ``batch_norms`` -- dict mapping name to the object returned
                by ``ops.batch_norm``. These objects are not appended to
                ``gen_vars`` / ``disc_vars``.
              * ``gen_vars`` -- generator weights/biases, in creation order.
              * ``disc_vars`` -- discriminator weights/biases, in creation
                order.

        Raises:
            IndexError: if a ``'ksize'`` list is shorter than the matching
                ``'dim'`` list.
        """
        all_weights = {}
        batch_norms = {}
        gen_vars = []
        disc_vars = []

        # ---- Generator -------------------------------------------------
        gen_dims = self.generator_params['dim']
        gen_ksizes = self.generator_params['ksize']
        last_gen_layer = len(gen_dims) - 1
        prev_layer_dim = self.z_dim
        # enumerate() instead of xrange(len(...)): same iteration on
        # Python 2, and still valid on Python 3 where xrange is gone.
        for layer_i, out_dim in enumerate(gen_dims):
            ksize = gen_ksizes[layer_i]
            name = 'gen_w' + str(layer_i)
            # Layout [k, k, out_channels, in_channels] is the filter order
            # tf.nn.conv2d_transpose expects; presumably the generator's
            # forward pass uses transposed convolutions -- TODO confirm.
            all_weights[name] = ops.variable(
                name,
                [ksize, ksize, out_dim, prev_layer_dim],
                self.initializer)
            gen_vars.append(all_weights[name])

            if layer_i == last_gen_layer:
                # Output layer gets a bias instead of batch norm. It is
                # zero-initialised explicitly so it is created the same way
                # as the discriminator biases below.
                name = 'gen_b' + str(layer_i)
                all_weights[name] = ops.variable(
                    name, [out_dim], tf.constant_initializer(0.0))
                gen_vars.append(all_weights[name])
            else:
                name = 'gen_bn' + str(layer_i)
                batch_norms[name] = ops.batch_norm(out_dim, name=name)

            prev_layer_dim = out_dim

        # ---- Discriminators --------------------------------------------
        for disc_i, cur_params in enumerate(self.discriminators_params):
            disc_dims = cur_params['dim']
            disc_ksizes = cur_params['ksize']
            last_disc_layer = len(disc_dims) - 1
            # Each discriminator's channel chain starts at self.image_dim.
            prev_layer_dim = self.image_dim
            for layer_i, out_dim in enumerate(disc_dims):
                ksize = disc_ksizes[layer_i]
                name = 'disc' + str(disc_i) + '_w' + str(layer_i)
                # Layout [k, k, in_channels, out_channels] matches the
                # tf.nn.conv2d filter order -- TODO confirm against the
                # forward pass.
                all_weights[name] = ops.variable(
                    name,
                    [ksize, ksize, prev_layer_dim, out_dim],
                    self.initializer)
                disc_vars.append(all_weights[name])

                if layer_i == last_disc_layer:
                    name = 'disc' + str(disc_i) + '_b' + str(layer_i)
                    all_weights[name] = ops.variable(
                        name, [out_dim], tf.constant_initializer(0.0))
                    disc_vars.append(all_weights[name])
                else:
                    # NOTE(review): batch-norm keys use 'disc_<i>_bn<j>'
                    # (extra underscore), unlike the 'disc<i>_w<j>' weight
                    # keys. Left unchanged because other code presumably
                    # looks these up by this exact key.
                    name = 'disc_' + str(disc_i) + '_bn' + str(layer_i)
                    batch_norms[name] = ops.batch_norm(out_dim, name=name)

                prev_layer_dim = out_dim

        return all_weights, batch_norms, gen_vars, disc_vars
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号