depth_model.py source code

python

Project: dwt  Author: min2209
import tensorflow as tf  # TensorFlow 1.x API

def get_conv_filter(self, params):
    # Use stored weights if self.modelDict has an entry for this layer.
    if params["name"] + "/weights" in self.modelDict:
        init = tf.constant_initializer(value=self.modelDict[params["name"] + "/weights"], dtype=tf.float32)
        var = tf.get_variable(name="weights", initializer=init, shape=params["shape"])
        print("loaded " + params["name"] + "/weights")
    else:
        if params["std"]:
            stddev = params["std"]
        else:
            # Fan-in of a conv filter: height * width * input channels.
            fanIn = params["shape"][0] * params["shape"][1] * params["shape"][2]
            stddev = (2 / float(fanIn)) ** 0.5

        init = tf.truncated_normal(shape=params["shape"], stddev=stddev, seed=0)
        var = tf.get_variable(name="weights", initializer=init)
        print("generated " + params["name"] + "/weights")

    # Add the L2 weight-decay term only when the scope is not reusing variables.
    if not tf.get_variable_scope().reuse:
        # tf.mul was renamed to tf.multiply in TensorFlow 1.0.
        weightDecay = tf.multiply(tf.nn.l2_loss(var), self._wd,
                                  name='weight_loss')
        tf.add_to_collection('losses', weightDecay)

    return var
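
When no pretrained weights and no explicit "std" are given, the method falls back to a standard deviation of sqrt(2 / fan-in), which matches the scaling commonly called He initialization. The following is a minimal sketch of that arithmetic in plain Python, without TensorFlow. The helper name he_stddev and the example shape are hypothetical and not part of the dwt project; the shape is assumed to follow the [height, width, in_channels, out_channels] layout used for TensorFlow conv filters.

```python
import math

def he_stddev(shape):
    """Return sqrt(2 / fan_in) for a conv filter shape.

    Assumes shape = [height, width, in_channels, out_channels];
    he_stddev is a hypothetical helper used only for illustration.
    """
    fan_in = shape[0] * shape[1] * shape[2]
    return math.sqrt(2.0 / fan_in)

# Example: a 3x3 filter over 64 input channels has fan-in 3 * 3 * 64 = 576.
example_shape = [3, 3, 64, 128]
print(he_stddev(example_shape))
```

Note that the number of output channels (the last entry of the shape) does not affect the result, because only the first three dimensions enter the fan-in.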