layers.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:text-gan-tensorflow 作者: tokestermw 项目源码 文件源码
def layer(func):
    """Turn ``func`` into a reusable, variable-sharing layer class.

    The returned ``Layer`` class is instantiated with extra positional and
    keyword arguments. Calling an instance on an input ``x`` evaluates
    ``func(x, *args, **kwargs)`` through a ``tf.make_template`` so that
    repeated calls on the same instance reuse the same variables.

    Keyword arguments inspected by ``Layer`` (note: they are *also*
    forwarded unchanged to ``func``):
        name: name given to ``tf.make_template``; defaults to
            ``func.__name__``.
        summary: when truthy, a scalar summary of the mean of every
            trainable variable in the layer's scope is added once, after the
            first call.

    Args:
        func: the layer body, called as ``func(x, *args, **kwargs)``.

    Returns:
        The ``Layer`` class wrapping ``func``.
    """

    class Layer(object):
        def __init__(self, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs
            self.name = self.kwargs.get("name", self.func.__name__)

            # create_scope_now_=True creates the variable scope at construction
            # time, so its name can be read immediately below.
            self._template = tf.make_template(self.name, self.func, create_scope_now_=True)
            # Last component of the scope path TensorFlow actually assigned.
            self._unique_name = self._template.variable_scope.name.split("/")[-1]
            self._summary_added = False

        def __call__(self, x):
            """Apply the layer to ``x``, log shapes, and add summaries if requested."""
            out = self.template(x, *self.args, **self.kwargs)
            self._layer_logging(x, out)
            self._add_summary()
            return out

        def __rrshift__(self, other):
            """Support ``x >> layer`` as shorthand for ``layer(x)``."""
            return self.__call__(other)

        def _layer_logging(self, other, out):
            """Log the input -> output static shapes at INFO level."""
            tf.logging.info("     {} {} {} -> {}".format(
                self.unique_name, "shape", str(other.get_shape()), str(out.get_shape())))

        def _add_summary(self):
            """Add mean-value scalar summaries for this layer's variables, once.

            Does nothing unless the layer was built with a truthy ``summary``
            kwarg, and does nothing after the summaries were already added.
            """
            if not self.kwargs.get("summary"):
                return None
            if self.summary_added:
                return None
            for var in self.get_variables_in_scope():
                # TODO: different summary types
                tf.summary.scalar(var.name, tf.reduce_mean(var))
            self._summary_added = True

        @staticmethod
        def _scope_pattern(scope_name):
            """Return a ``re.match`` pattern selecting names inside ``scope_name``.

            ``tf.get_collection`` filters with ``re.match`` (a regex prefix
            match). A bare scope "dense" would therefore also select
            "dense_1/..." variables, and regex metacharacters in the name
            would be interpreted. Escaping the name and requiring the
            trailing "/" restricts the match to this exact scope and its
            sub-scopes.
            """
            import re
            return re.escape(scope_name) + "/"

        def get_variables_in_scope(self):
            """Return the trainable variables created under this layer's scope.

            Raises:
                AssertionError: (when assertions are enabled) if the template
                    has not created its variables yet, i.e. the layer has not
                    been called.
            """
            assert self.template._variables_created, "Variables not yet created or undefined."
            variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES,
                scope=self._scope_pattern(self.variable_scope_name))
            return variables

        @property
        def template(self):
            """The ``tf.make_template`` object wrapping ``func``."""
            return self._template

        @property
        def unique_name(self):
            """Last component of the template's variable-scope name."""
            return self._unique_name

        @property
        def variable_scope_name(self):
            """Full name of the template's variable scope."""
            # Public accessor (the same one __init__ uses) rather than the
            # private ``_variable_scope._name`` attributes.
            return self.template.variable_scope.name

        @property
        def summary_added(self):
            """Whether the scalar summaries have already been added."""
            return self._summary_added

    return Layer
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号