def _build(self, X):
    """Build the graph of this layer.

    Builds the weight prior (``_make_prior``) and posterior
    (``_make_posterior``), draws ``n_samples`` weight samples from the
    posterior via ``_sample_W``, and for each sample gathers the weight rows
    selected by the matching slice of ``X``. This index lookup is used
    instead of a sparse matmul.

    Parameters
    ----------
    X : Tensor
        Input tensor. Axis 0 is iterated in lockstep with the weight samples
        and axis 1 is the batch axis. Its values are passed to ``tf.gather``
        as row indices into each weight sample, so they must be integers
        (presumably category ids in ``[0, n_categories)``; TODO confirm
        against the caller).

    Returns
    -------
    Net : Tensor
        The gathered features reshaped to ``[n_samples, n_batch, f_dims]``,
        where ``f_dims`` is the product of the gathered features' static
        dimensions beyond axis 1.
    KL : Tensor
        The regulariser ``kl_sum(self.qW, self.pW)``.

    Notes
    -----
    Side effect: overwrites ``self.pW`` and ``self.qW`` with the objects
    returned by ``_make_prior`` and ``_make_posterior`` respectively.
    """
    # Only the sample count is needed here; the input dims are not used.
    n_samples, _ = self._get_X_dims(X)
    W_shape, _ = self._weight_shapes(self.n_categories)
    n_batch = tf.shape(X)[1]

    # Layer weights
    self.pW = _make_prior(self.std, self.pW, W_shape)
    self.qW = _make_posterior(self.std, self.qW, W_shape, self.full)

    # Index into the relevant weights rather than using a sparse matmul:
    # map_fn pairs each weight sample with the matching slice of X along
    # axis 0, and tf.gather picks out the rows of that sample indexed by X.
    Wsamples = _sample_W(self.qW, n_samples)
    features = tf.map_fn(
        lambda wx: tf.gather(*wx, axis=0),
        (Wsamples, X),
        dtype=Wsamples.dtype,
    )

    # Flatten everything after the (sample, batch) axes into one last axis.
    # f_dims comes from the *static* shape so the reshape target has a known
    # last dimension (needed for placeholders). This requires the trailing
    # dimensions of `features` to be statically known.
    f_dims = int(np.prod(features.shape[2:]))
    Net = tf.reshape(features, [n_samples, n_batch, f_dims])

    # Regularizers
    KL = kl_sum(self.qW, self.pW)
    return Net, KL