def _initialize_params(self):
    """Create the weight variables and batch-norm objects for every network.

    One filter variable is created per entry of ``generator_params['dim']``
    and per entry of each discriminator's ``'dim'`` list.  The last layer of
    each network additionally gets a bias variable; every other layer gets
    an ``ops.batch_norm`` object instead.

    Returns:
        A 4-tuple ``(all_weights, batch_norms, gen_vars, disc_vars)``.
        ``all_weights`` maps each variable name to the value returned by
        ``ops.variable`` (generator and discriminators together).
        ``batch_norms`` maps each batch-norm name to the object returned by
        ``ops.batch_norm``.  ``gen_vars`` and ``disc_vars`` list the
        generator and discriminator entries of ``all_weights`` in creation
        order.
    """
    all_weights = {}
    batch_norms = {}
    gen_vars = []
    disc_vars = []

    # --- generator ---------------------------------------------------------
    # The first generator layer takes self.z_dim input channels.
    gen_dims = self.generator_params['dim']
    last_gen_layer = len(gen_dims) - 1
    in_dim = self.z_dim
    for layer_i, out_dim in enumerate(gen_dims):
        ksize = self.generator_params['ksize'][layer_i]

        name = 'gen_w' + str(layer_i)
        # Filter shape is [k, k, out, in] here, the reverse channel order of
        # the discriminator filters below (presumably because the generator
        # uses transposed convolutions -- TODO confirm against the caller).
        weight = ops.variable(name,
                              [ksize, ksize, out_dim, in_dim],
                              self.initializer)
        all_weights[name] = weight
        gen_vars.append(weight)

        if layer_i == last_gen_layer:
            name = 'gen_b' + str(layer_i)
            # NOTE(review): unlike the discriminator bias, no initializer is
            # passed here, so ops.variable's default is used. Confirm this
            # is intended rather than a dropped tf.constant_initializer(0.0).
            bias = ops.variable(name, [out_dim])
            all_weights[name] = bias
            gen_vars.append(bias)
        else:
            name = 'gen_bn' + str(layer_i)
            batch_norms[name] = ops.batch_norm(out_dim, name=name)

        in_dim = out_dim

    # --- discriminators ----------------------------------------------------
    for disc_i, cur_params in enumerate(self.discriminators_params):
        prefix = 'disc' + str(disc_i)
        disc_dims = cur_params['dim']
        last_disc_layer = len(disc_dims) - 1
        # Every discriminator starts from self.image_dim input channels.
        in_dim = self.image_dim
        for layer_i, out_dim in enumerate(disc_dims):
            ksize = cur_params['ksize'][layer_i]

            name = prefix + '_w' + str(layer_i)
            # Filter shape is [k, k, in, out].
            weight = ops.variable(name,
                                  [ksize, ksize, in_dim, out_dim],
                                  self.initializer)
            all_weights[name] = weight
            disc_vars.append(weight)

            if layer_i == last_disc_layer:
                name = prefix + '_b' + str(layer_i)
                bias = ops.variable(name,
                                    [out_dim],
                                    tf.constant_initializer(0.0))
                all_weights[name] = bias
                disc_vars.append(bias)
            else:
                # NOTE(review): batch-norm names carry an extra underscore
                # ('disc_<i>_bn<j>') compared with the weight names
                # ('disc<i>_w<j>').  Kept as-is because these strings are
                # dict keys / variable names other code may depend on.
                name = 'disc_' + str(disc_i) + '_bn' + str(layer_i)
                batch_norms[name] = ops.batch_norm(out_dim, name=name)

            in_dim = out_dim

    return all_weights, batch_norms, gen_vars, disc_vars
# Comment list
# Table of contents