def layer(func):
    """Turn ``func`` into a reusable, variable-sharing layer class.

    ``layer(func)`` returns a ``Layer`` class. Instantiating
    ``Layer(*args, **kwargs)`` wraps ``func`` in ``tf.make_template`` and stores
    ``args``/``kwargs``. Applying the instance to an input ``x``, either as
    ``instance(x)`` or ``x >> instance``, evaluates
    ``func(x, *args, **kwargs)`` through that template.

    Keyword options read from ``kwargs``:
        name: name passed to ``tf.make_template``; defaults to ``func.__name__``.
        summary: if truthy, the first call adds a scalar summary of the mean
            of each trainable variable found in the layer's variable scope.

    NOTE(review): ``name`` and ``summary`` are forwarded to ``func`` together
    with every other kwarg, so ``func`` presumably has to accept them - confirm.

    Args:
        func: the layer body, called as ``func(x, *args, **kwargs)``.

    Returns:
        The ``Layer`` class bound to ``func``.
    """

    class Layer(object):
        """One configured application of ``func`` with its own template scope."""

        def __init__(self, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs
            self.name = self.kwargs.get("name", self.func.__name__)
            # create_scope_now_=True opens the template's variable scope right
            # away, so its final name is already available below.
            self._template = tf.make_template(self.name, self.func, create_scope_now_=True)
            # Last path component of the scope TF actually assigned; this may
            # differ from self.name if that scope name was already in use.
            self._unique_name = self._template.variable_scope.name.split("/")[-1]
            self._summary_added = False

        def __call__(self, x):
            """Apply the layer to ``x`` and return the template's output.

            Also logs the input/output shapes. When the ``summary`` option is
            set, it adds variable summaries on the first call only.
            """
            out = self.template(x, *self.args, **self.kwargs)
            self._layer_logging(x, out)
            self._add_summary()
            return out

        def __rrshift__(self, other):
            """Support ``x >> layer_instance`` as an alias for ``layer_instance(x)``.

            Python only falls back to this when ``type(x)`` does not handle
            ``>>`` itself.
            """
            return self.__call__(other)

        def _layer_logging(self, other, out):
            """Log "<unique_name> shape <input shape> -> <output shape>" at INFO."""
            tf.logging.info(" {} {} {} -> {}".format(
                self.unique_name, "shape", str(other.get_shape()), str(out.get_shape())))

        def _add_summary(self):
            """Add one mean-value scalar summary per trainable variable, at most once.

            Does nothing unless the ``summary`` option is truthy.
            """
            if not self.kwargs.get("summary"):
                return None
            if self.summary_added:
                return None
            for var in self.get_variables_in_scope():
                # TODO: different summary types
                # ``var.op.name`` drops the ":0" output suffix that ``var.name``
                # carries. ':' is not a legal summary-name character, so TF
                # would otherwise rewrite each tag and log a warning.
                tf.summary.scalar(var.op.name, tf.reduce_mean(var))
            self._summary_added = True

        def get_variables_in_scope(self):
            """Return the trainable variables created inside this layer's scope.

            Raises:
                AssertionError: if the template reports its variables as not yet
                    created (presumably because the layer has not been called).
            """
            import re

            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``. AssertionError is kept for existing callers.
            if not self.template._variables_created:
                raise AssertionError("Variables not yet created or undefined.")
            # tf.get_collection filters names with re.match, i.e. a regex prefix
            # match. The trailing "/" stops scope "dense" from also matching a
            # sibling "dense_1". re.escape keeps characters such as "." literal.
            scope_pattern = re.escape(self.variable_scope_name + "/")
            return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope_pattern)

        @property
        def template(self):
            """The ``tf.make_template`` wrapper around ``func``."""
            return self._template

        @property
        def unique_name(self):
            """Last path component of the template's variable-scope name."""
            return self._unique_name

        @property
        def variable_scope_name(self):
            """Full name of the template's variable scope."""
            # Public accessor; returns the same value as the private
            # ``_variable_scope._name`` chain.
            return self.template.variable_scope.name

        @property
        def summary_added(self):
            """Whether ``_add_summary`` has already emitted summaries."""
            return self._summary_added

    return Layer
评论列表
文章目录