BaseModel.py source code

Project: kaggle-review  Author: daxiongshu
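```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.get_variable, tf.GraphKeys)


def _get_weight_variable(self, layer_name, name, shape, L2=1):
    """Create a weight variable, loading a pretrained value when one exists.

    Method of the model class in BaseModel.py (uses self.weights,
    self.loaded_weights and self._xavi_norm). Call it inside
    tf.variable_scope(layer_name) so the full name is '<layer_name>/<name>:0'.
    """
    wname = '%s/%s:0' % (layer_name, name)
    # Fan-in/fan-out for Xavier init: the last two dims are (in, out);
    # leading dims (e.g. conv kernel height and width) scale both.
    fanin, fanout = shape[-2:]
    for dim in shape[:-2]:
        fanin *= float(dim)
        fanout *= float(dim)
    sigma = self._xavi_norm(fanin, fanout)

    if self.weights is None or wname not in self.weights:
        # No pretrained value: draw from a truncated normal.
        w1 = tf.get_variable(
            name,
            initializer=tf.truncated_normal(shape=shape, mean=0.0, stddev=sigma))
        print('{:>23} {:>23}'.format(wname, 'randomly initialize'))
    else:
        # Pretrained value found: initialize from it and mark it as used.
        w1 = tf.get_variable(
            name, shape=shape,
            initializer=tf.constant_initializer(value=self.weights[wname], dtype=tf.float32))
        self.loaded_weights[wname] = 1

    # A scope mismatch would make the pretrained lookup silently miss.
    if wname != w1.name:
        raise ValueError('variable name mismatch: expected %s, got %s' % (wname, w1.name))

    # Register the L2 penalty (scaled by L2) so it can be added to the total loss.
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, tf.nn.l2_loss(w1) * L2)
    return w1
```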
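The fan-in/fan-out bookkeeping can be checked without TensorFlow. The sketch below mirrors that loop; since `_xavi_norm` is not shown in this file, it assumes the common Glorot/Xavier normal standard deviation `sqrt(2 / (fan_in + fan_out))`. The helper names `fan_in_out` and `xavier_stddev` are hypothetical.

```python
import math


def fan_in_out(shape):
    """Mirror the loop in _get_weight_variable: the last two dims are
    (in, out) and every leading dim (e.g. a conv kernel's height and
    width) multiplies both."""
    fan_in, fan_out = float(shape[-2]), float(shape[-1])
    for dim in shape[:-2]:
        fan_in *= dim
        fan_out *= dim
    return fan_in, fan_out


def xavier_stddev(fan_in, fan_out):
    # Assumption: Glorot/Xavier normal init; the real _xavi_norm may differ.
    return math.sqrt(2.0 / (fan_in + fan_out))


# A 3x3 convolution kernel mapping 64 input channels to 128 output channels.
fan_in, fan_out = fan_in_out([3, 3, 64, 128])
print(fan_in, fan_out)  # 576.0 1152.0
print(xavier_stddev(fan_in, fan_out))
```

For a dense layer with shape `[in, out]` the loop body never runs, so the fans are simply `in` and `out`.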
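The load-or-initialize logic and the L2 penalty can also be sketched framework-free. This is a minimal pure-Python approximation, not the author's code: the hypothetical `get_weight` takes `pretrained` in place of `self.weights` (a dict keyed by full names such as `'fc1/w:0'`) and `loaded` in place of `self.loaded_weights`, and stores weights as flat lists. The truncated normal redraws samples beyond two standard deviations, which is how `tf.truncated_normal` is documented to behave, and the penalty uses the `tf.nn.l2_loss` definition `sum(w ** 2) / 2`.

```python
import math
import random


def truncated_normal(n, stddev, rng):
    """Sample n values from N(0, stddev^2), redrawing any beyond 2 stddev."""
    values = []
    while len(values) < n:
        x = rng.gauss(0.0, stddev)
        if abs(x) <= 2 * stddev:
            values.append(x)
    return values


def get_weight(pretrained, loaded, layer_name, name, shape, stddev, l2_scale=1.0, rng=None):
    """Hypothetical framework-free analogue of _get_weight_variable.

    Weights are flat lists of length prod(shape). Returns (weights, penalty).
    """
    rng = rng if rng is not None else random.Random()
    wname = '%s/%s:0' % (layer_name, name)
    size = math.prod(shape)
    if pretrained is None or wname not in pretrained:
        w = truncated_normal(size, stddev, rng)
    else:
        w = [float(x) for x in pretrained[wname]]
        if len(w) != size:
            raise ValueError('%s: expected %d values, got %d' % (wname, size, len(w)))
        loaded[wname] = 1  # record which pretrained tensors were consumed
    # Same quantity as tf.nn.l2_loss(w) * L2: half the sum of squares, scaled.
    penalty = 0.5 * sum(x * x for x in w) * l2_scale
    return w, penalty


pretrained = {'fc1/w:0': [1.0] * 6}
loaded = {}
w, penalty = get_weight(pretrained, loaded, 'fc1', 'w', (2, 3), stddev=0.1, l2_scale=0.01)
print(loaded)             # {'fc1/w:0': 1}
print(round(penalty, 6))  # 0.03
```

After building the model, comparing `loaded` against the keys of `pretrained` shows which saved tensors were never used, which is the purpose `self.loaded_weights` appears to serve.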