def _initialize_params(self):
    """Create the weight, bias, and batch-norm variables for both sub-networks.

    Reads the layer configuration from ``self.generator_params`` and
    ``self.discriminator_params`` (each assumed to be a dict holding equal-length
    'dim' and 'ksize' lists — TODO confirm against the constructor), plus
    ``self.z_dim`` / ``self.image_dim`` as the respective input channel counts
    and ``self.initializer`` for the convolution kernels.

    Returns:
        tuple: ``(all_weights, batch_norms, gen_vars, disc_vars)`` where
        ``all_weights`` maps variable name -> variable, ``batch_norms`` maps
        name -> ``ops.batch_norm`` object, and ``gen_vars`` / ``disc_vars``
        list the weight/bias variables of the generator / discriminator.
    """
    all_weights = {}
    batch_norms = {}
    gen_vars = []
    disc_vars = []

    # ---- generator weights ----
    # Transposed-convolution filter shape: [ksize, ksize, out_ch, in_ch] —
    # note that the OUTPUT channels come before the input channels, the
    # reverse of the discriminator's forward convolutions below.
    prev_layer_dim = self.z_dim
    n_gen_layers = len(self.generator_params['dim'])
    for layer_i in xrange(n_gen_layers):
        out_dim = self.generator_params['dim'][layer_i]
        ksize = self.generator_params['ksize'][layer_i]
        name = 'gen_w' + str(layer_i)
        all_weights[name] = ops.variable(
            name, [ksize, ksize, out_dim, prev_layer_dim], self.initializer)
        gen_vars.append(all_weights[name])
        if layer_i + 1 == n_gen_layers:
            # Output layer gets a bias instead of batch norm.
            # BUG FIX: the original call omitted the initializer argument
            # entirely (trailing comma); initialize the bias to zero, consistent
            # with the discriminator bias below.
            name = 'gen_b' + str(layer_i)
            all_weights[name] = ops.variable(
                name, [out_dim], tf.constant_initializer(0.0))
            gen_vars.append(all_weights[name])
        else:
            # Hidden layers are batch-normalized; no per-layer bias is created.
            # NOTE(review): batch-norm objects are not appended to gen_vars —
            # presumably ops.batch_norm tracks its own trainables; confirm.
            name = 'gen_bn' + str(layer_i)
            batch_norms[name] = ops.batch_norm(out_dim, name=name)
        prev_layer_dim = out_dim

    # ---- discriminator weights ----
    # Forward-convolution filter shape: [ksize, ksize, in_ch, out_ch].
    prev_layer_dim = self.image_dim
    n_disc_layers = len(self.discriminator_params['dim'])
    for layer_i in xrange(n_disc_layers):
        out_dim = self.discriminator_params['dim'][layer_i]
        ksize = self.discriminator_params['ksize'][layer_i]
        name = 'disc_w' + str(layer_i)
        all_weights[name] = ops.variable(
            name, [ksize, ksize, prev_layer_dim, out_dim], self.initializer)
        disc_vars.append(all_weights[name])
        if layer_i + 1 == n_disc_layers:
            # Final layer: zero-initialized bias, no batch norm.
            name = 'disc_b' + str(layer_i)
            all_weights[name] = ops.variable(
                name, [out_dim], tf.constant_initializer(0.0))
            disc_vars.append(all_weights[name])
        else:
            name = 'disc_bn' + str(layer_i)
            batch_norms[name] = ops.batch_norm(out_dim, name=name)
        prev_layer_dim = out_dim
    return all_weights, batch_norms, gen_vars, disc_vars
# NOTE(review): removed trailing non-code scrape artifacts ("评论列表" / "文章目录",
# i.e. blog "comment list" / "table of contents" navigation text) — not Python.