def __call__(self, x, deterministic, train_clip=False, thresh=3):
    """Apply the sparse variational-dropout linear layer to ``x``.

    Each weight has mean ``self.W`` and log-variance ``self.log_sigma2``.
    ``log_alpha = log(sigma^2 / W^2)`` measures how noisy a weight is, and
    weights whose ``log_alpha`` reaches ``thresh`` are treated as pruned.

    Args:
        x: Input tensor, multiplied as ``x @ W``. Presumably shaped
            ``(batch, in_features)``; TODO confirm against callers.
        deterministic: True selects the inference path and false selects
            the stochastic training path. A boolean tensor is dispatched
            with ``tf.cond``. A plain Python ``bool`` is dispatched
            directly, because TF1 ``tf.cond`` rejects Python bools.
        train_clip: If true, pruned weights are also zeroed in the mean
            of the training path. The noise variance still uses the
            unmasked weights.
        thresh: ``log_alpha`` threshold at or above which a weight counts
            as pruned.

    Returns:
        ``self.nonlinearity(h + self.b)``, where ``h`` is the (possibly
        sampled) pre-activation.
    """
    # alpha = sigma^2 / W^2 is the per-weight multiplicative-noise variance.
    # It is not the dropout rate itself; in the Gaussian-dropout view the
    # two are related by alpha = p / (1 - p).
    # `eps` keeps the log finite when W == 0.
    # `clip` is defined elsewhere and presumably bounds log_alpha.
    log_alpha = clip(self.log_sigma2 - tf.log(self.W ** 2 + eps))
    # True wherever a weight is considered pruned (log_alpha >= thresh).
    clip_mask = tf.greater_equal(log_alpha, thresh)

    def true_path():  # inference
        # Use the mean weights, with pruned entries replaced by zero.
        return tf.matmul(x, tf.where(clip_mask, tf.zeros_like(self.W), self.W))

    def false_path():  # training
        # Local reparameterization: sample the pre-activation directly from
        # N(mu, si^2), where mu = x @ W and si^2 = (x * x) @ (alpha * W^2).
        W = self.W
        if train_clip:
            # Zero pruned weights in the mean only; the variance term below
            # deliberately keeps the unmasked self.W.
            W = tf.where(clip_mask, tf.zeros_like(W), W)
        mu = tf.matmul(x, W)
        si = tf.matmul(x * x, tf.exp(log_alpha) * self.W * self.W)
        # eps keeps sqrt (and its gradient) finite when the variance is 0.
        si = tf.sqrt(si + eps)
        # Match the noise dtype to mu so non-float32 layers also work.
        noise = tf.random_normal(
            tf.shape(mu), mean=0.0, stddev=1.0, dtype=mu.dtype)
        return mu + noise * si

    if isinstance(deterministic, bool):
        # A static Python flag needs no graph-level branching.
        h = true_path() if deterministic else false_path()
    else:
        h = tf.cond(deterministic, true_path, false_path)
    return self.nonlinearity(h + self.b)
# Web-page residue from the scraped source, not code:
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")