layers.py source code

python
Project: third_person_im  Author: bstadie
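def add_param(self, spec, shape, name, **kwargs):
    # Method excerpt from a layer class; `tf` is the TensorFlow module imported by the file.
    param = self.add_param_plain(spec, shape, name, **kwargs)
    if name is not None and name.startswith("W") and self.weight_normalization:
        # Hacky: treat any parameter whose name starts with "W" as a weight tensor
        # and reparameterize it as w = v * g / ||v|| (weight normalization).
        # Note: `keep_dims` is the older spelling; newer TensorFlow releases use `keepdims`.
        if len(param.get_shape()) == 2:
            # Dense weights: one scale g per output unit (last axis), norm taken over axis 0.
            v = param
            g = self.add_param_plain(tf.ones_initializer, (shape[1],), name=name + "_wn/g")
            param = v * (tf.reshape(g, (1, -1)) / tf.sqrt(tf.reduce_sum(tf.square(v), 0, keep_dims=True)))
        elif len(param.get_shape()) == 4:
            # Conv filters: one scale g per output channel (last axis), norm taken over axes 0-2.
            v = param
            g = self.add_param_plain(tf.ones_initializer, (shape[3],), name=name + "_wn/g")
            param = v * (tf.reshape(g, (1, 1, 1, -1)) / tf.sqrt(tf.reduce_sum(tf.square(v), [0, 1, 2],
                                                                              keep_dims=True)))
        else:
            # Only 2-D and 4-D weight tensors are supported.
            raise NotImplementedError
    return param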
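
The 2-D branch above rescales every column of `v` so that its L2 norm equals the matching entry of `g`. The sketch below reproduces that arithmetic in plain Python so it can be checked without TensorFlow. The function name `weight_normalize_2d` and the sample values are illustrative assumptions, not part of the original project.

```python
import math


def weight_normalize_2d(v, g):
    """Rescale each column j of v so that its L2 norm equals g[j].

    v: list of rows with shape (n_in, n_out); g: list of n_out scales.
    Hypothetical helper mirroring w = v * (g / sqrt(sum(v**2, axis=0))).
    """
    n_in, n_out = len(v), len(v[0])
    if len(g) != n_out:
        raise ValueError("g must have one entry per output column")
    # Column-wise L2 norms, i.e. the reduction over axis 0.
    norms = [math.sqrt(sum(v[i][j] ** 2 for i in range(n_in))) for j in range(n_out)]
    # Broadcast g / norm across the rows of each column.
    return [[v[i][j] * g[j] / norms[j] for j in range(n_out)] for i in range(n_in)]


# Sample values chosen for easy checking (assumed, not from the source).
v = [[3.0, 1.0],
     [4.0, 0.0]]
g = [1.0, 2.0]
w = weight_normalize_2d(v, g)

# After normalization, the norm of column j should match g[j].
col_norms = [math.sqrt(sum(row[j] ** 2 for row in w)) for j in range(2)]
```

Like the original expression, this sketch has no epsilon term, so a column of all zeros would divide by zero.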