def summary_param(op, tensor, ndims, name, collections=None):
    """Add one TensorFlow summary for ``tensor`` and return it.

    Only the summary selected by ``op`` is created, so each call adds at most
    one summary to the graph and to ``collections``.

    Args:
        op: name of the summary op; one of 'scalar', 'histogram', 'sparsity',
            'mean', 'rms', 'stddev', 'norm', 'max', 'min'.
        tensor: the tensor to summarize.
        ndims: number of dimensions of ``tensor``. For 'scalar', a value of 0
            logs the tensor directly, and any other value logs its mean under
            ``name + '/mean'``. For 'histogram', values below 2 produce no
            summary.
        name: base name of the summary. Every op except 'histogram' and
            rank-0 'scalar' appends a '/<suffix>' to it.
        collections: graph collections (e.g. training or validation) passed
            straight through to ``tf.summary.scalar`` / ``tf.summary.histogram``.

    Returns:
        The summary op returned by ``tf.summary.scalar`` or
        ``tf.summary.histogram``, or ``None`` for 'histogram' when ``ndims < 2``.

    Raises:
        KeyError: if ``op`` is not one of the supported names.
    """
    def _scalar(suffix, value):
        # Shared helper: scalar summary named ``name + suffix``.
        return tf.summary.scalar(name + suffix, value, collections=collections)

    def _scalar_summary():
        # A rank-0 tensor is logged as-is; higher-rank tensors are reduced to their mean.
        if ndims == 0:
            return tf.summary.scalar(name, tensor, collections=collections)
        return _scalar('/mean', tf.reduce_mean(tensor))

    def _histogram_summary():
        # Histograms are only emitted for tensors with at least 2 dimensions.
        if ndims >= 2:
            return tf.summary.histogram(name, tensor, collections=collections)
        return None

    def _stddev():
        # Population standard deviation: sqrt(mean((x - mean(x))^2)).
        # Using reduce_sum here would yield the L2 norm of the centred tensor.
        centered = tensor - tf.reduce_mean(tensor, name='mean_op')
        return tf.sqrt(tf.reduce_mean(tf.square(centered)), name='stddev_op')

    # Values are zero-argument callables so that only the requested summary is
    # built. Evaluating every tf.summary.* call up front would add all of them
    # to the graph (and to `collections`) on each call.
    builders = {
        'scalar': _scalar_summary,
        'histogram': _histogram_summary,
        'sparsity': lambda: _scalar('/sparsity', tf.nn.zero_fraction(tensor)),
        'mean': lambda: _scalar('/mean', tf.reduce_mean(tensor)),
        # `rms` is a helper defined elsewhere in this module.
        'rms': lambda: _scalar('/rms', rms(tensor)),
        'stddev': lambda: _scalar('/stddev', _stddev()),
        'max': lambda: _scalar('/max', tf.reduce_max(tensor)),
        'min': lambda: _scalar('/min', tf.reduce_min(tensor)),
        'norm': lambda: _scalar('/norm', tf.sqrt(tf.reduce_sum(tensor * tensor))),
    }

    if op not in builders:
        raise KeyError('unknown summary op {!r}; expected one of {}'.format(
            op, sorted(builders)))
    return builders[op]()
# Web-page scrape residue (not code): "评论列表" (comment list)
# and "文章目录" (table of contents).