def get_number_of_trainable_parameters():
    """Count the trainable parameters of the default TensorFlow graph.

    Iterates over ``tf.trainable_variables()`` and sums the number of
    elements of each variable, logging per-variable and total counts at
    DEBUG level.

    Returns:
        int: total number of scalar trainable parameters.

    Raises:
        ValueError: if a trainable variable's shape is not fully defined,
            so its element count cannot be computed.
    """
    # https://stackoverflow.com/questions/38160940/ ...
    LOGGER.debug('Now compute total number of trainable params...')
    total_parameters = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        name = variable.name
        # num_elements() works whether iterating the shape yields Dimension
        # objects (TF1) or plain ints (v2 TensorShape behaviour); the old
        # ``dim.value`` product breaks on the latter. A scalar shape () gives
        # 1, matching the empty product.
        variable_parameters = shape.num_elements()
        if variable_parameters is None:
            # Previously surfaced as an opaque "None * int" TypeError.
            raise ValueError(
                'Cannot count parameters of variable {} with partially '
                'known shape {}'.format(name, shape)
            )
        LOGGER.debug(' layer name = %s, shape = %s, n_params = %s',
                     name, shape, variable_parameters)
        total_parameters += variable_parameters
    LOGGER.debug('Total parameters = %d', total_parameters)
    return total_parameters
# Web-page residue from the copied source, not code:
# 评论列表 ("comment list") / 文章目录 ("table of contents")