def print_training_params(self, cgs, training_params):
    """Log which parameters are excluded from, and included in, training.

    For every computation graph in ``cgs``, each parameter of the graph that
    is not listed in the matching ``training_params`` entry is logged. It is
    identified by its key in the merged encoder/decoder parameter dict, or
    by its own name if the encoder/decoder do not own it. The number of
    excluded parameters is logged after them. Then, for every entry of
    ``training_params``, the name of each parameter that will be trained
    and the total count are logged.

    Args:
        cgs: mapping from CG identifier to an object exposing a
            ``parameters`` iterable (presumably a Blocks
            ``ComputationGraph`` -- confirm against caller).
        training_params: mapping from CG identifier to the sequence of
            parameters trained for that graph. Must contain every key of
            ``cgs`` (a missing key raises ``KeyError``).

    Returns:
        None. All output goes to the module-level ``logger``.
    """
    enc_dec_param_dict = merge(self.encoder.get_params(),
                               self.decoder.get_params())
    # Reverse lookup parameter -> key, built once rather than rescanning the
    # whole dict for every excluded parameter. setdefault keeps the first
    # key seen if the same parameter object appears under several keys.
    param_to_key = {}
    for key, val in enc_dec_param_dict.items():
        param_to_key.setdefault(val, key)

    # Print which parameters are excluded
    for k, v in cgs.items():
        trained = set(training_params[k])
        # Walk the graph's own parameter order (deduplicated) so the log
        # output is deterministic, unlike iterating a set difference.
        excluded_all = []
        seen = set()
        for p in v.parameters:
            if p not in trained and p not in seen:
                seen.add(p)
                excluded_all.append(p)
        for p in excluded_all:
            # Fall back to the parameter's own name (or the object itself)
            # when it is not owned by the encoder/decoder; the previous
            # "[...][0]" lookup raised IndexError in that case.
            name = param_to_key.get(p, getattr(p, 'name', p))
            logger.info(
                'Excluding from training of CG[{}]: {}'.format(k, name))
        logger.info(
            'Total number of excluded parameters for CG[{}]: [{}]'
            .format(k, len(excluded_all)))

    # Print which parameters will be trained
    for k, v in training_params.items():
        for p in v:
            logger.info('Training parameter from CG[{}]: {}'
                        .format(k, p.name))
        logger.info(
            'Total number of parameters will be trained for CG[{}]: [{}]'
            .format(k, len(v)))
评论列表
文章目录