def get_conv_filter(self, params):
    """Create the ``weights`` variable for a convolution layer.

    The variable is obtained with ``tf.get_variable(name="weights", ...)``,
    so it lives in whatever variable scope is active at call time.

    Initialisation:
      * If ``params["name"] + "/weights"`` is a key of ``self.modelDict``,
        the stored value is used as a constant initializer.
      * Otherwise the weights are drawn from a truncated normal whose
        standard deviation is ``params["std"]`` when that is given and
        truthy, else ``sqrt(2 / fanIn)`` with ``fanIn`` being the product
        of the first three entries of ``params["shape"]``.

    Unless the current variable scope has ``reuse`` set, an L2 penalty
    ``l2_loss(var) * self._wd`` (op name ``weight_loss``) is added to the
    ``'losses'`` collection.

    Args:
        params: Mapping with the layer configuration. Keys read here:
            "name"  -- layer name used to build the ``modelDict`` key.
            "shape" -- filter shape passed to TensorFlow (presumably
                       [height, width, in_channels, out_channels];
                       TODO confirm against callers).
            "std"   -- optional stddev for random initialisation; a
                       missing or falsy value selects the fan-in default.

    Returns:
        The ``weights`` variable.
    """
    weightsKey = params["name"] + "/weights"
    shape = params["shape"]

    if weightsKey in self.modelDict:
        # A stored value exists for this layer: initialise from it.
        init = tf.constant_initializer(value=self.modelDict[weightsKey],
                                       dtype=tf.float32)
        var = tf.get_variable(name="weights", initializer=init, shape=shape)
        # Parenthesised single-argument print behaves the same on
        # Python 2 and Python 3.
        print("loaded " + weightsKey)
    else:
        # "std" is optional; .get avoids a KeyError when it is omitted.
        stddev = params.get("std")
        if not stddev:
            fanIn = shape[0] * shape[1] * shape[2]
            stddev = (2 / float(fanIn)) ** 0.5
        # The seed is fixed at 0, as in the original code.
        init = tf.truncated_normal(shape=shape, stddev=stddev, seed=0)
        var = tf.get_variable(name="weights", initializer=init)
        print("generated " + weightsKey)

    # Register the weight-decay term only when the variable is not being
    # reused, so each variable contributes to 'losses' once.
    if not tf.get_variable_scope().reuse:
        # tf.mul was removed in TensorFlow 1.0 in favour of tf.multiply;
        # use whichever one this installation provides.
        multiply = tf.multiply if hasattr(tf, "multiply") else tf.mul
        weightDecay = multiply(tf.nn.l2_loss(var), self._wd,
                               name='weight_loss')
        tf.add_to_collection('losses', weightDecay)

    return var
评论列表
文章目录