def get_log_probs(self, indices, splits, dims):
    """Gather the selected tree nodes and return their summed log-probabilities.

    For every gathered node a logit ``q = x . w + b`` is computed from
    ``self._input_layer``. The log-probability of the requested branch is
    ``log(sigmoid(q))`` where ``splits > 0`` and ``log(1 - sigmoid(q))``
    elsewhere, floored at ``log(1e-10)``. The per-node values are reshaped to
    ``[-1] + dims`` and summed over the last axis.

    Args:
        indices: Indices into the leading axis of ``self._W`` and ``self._b``.
            Assumed to be ``[batchsize, dim1size]`` so that the gathered
            weights are ``[batchsize, dim1size, inputlayersize]`` -- TODO
            confirm against the caller.
        splits: Tensor compared elementwise with 0 to choose the branch.
            Presumably the same shape as the ``[batchsize, dim1size]``
            logits -- verify against the caller.
        dims: Sequence of ints (list or tuple). Their product is the number
            of nodes per example, and they give the trailing shape used
            before the final reduction.

    Returns:
        Tensor of shape ``[-1] + dims[:-1]`` holding the log-probabilities
        summed over the last entry of ``dims``.
    """
    # Accept tuples as well as lists; ``[-1] + dims`` below needs a list.
    dims = list(dims)
    dim1size = int(np.prod(dims))

    # After the transpose: [batchsize, inputlayersize, dim1size].
    sampled_W = tf.transpose(tf.gather(self._W, indices), [0, 2, 1])
    # [batchsize, dim1size]
    sampled_b = tf.gather(self._b, indices)

    # input_layer is [batchsize, inputlayersize]; a length-1 axis is inserted
    # so the batched matmul yields [batchsize, 1, dim1size], which is then
    # flattened to [batchsize, dim1size]:  q = X*W + b.
    sampled_q = tf.reshape(
        tf.matmul(tf.expand_dims(self._input_layer, 1), sampled_W),
        [-1, dim1size]) + sampled_b

    # 1 - sigmoid(q) == sigmoid(-q), so pick the sign of the logit instead of
    # forming ``1 - p``. log_sigmoid avoids the overflow of exp(-q) and the
    # cancellation in ``1 - p`` that the explicit formula suffers from.
    signed_q = tf.where(splits > 0, sampled_q, -sampled_q)

    # Same lower bound as clipping the probability to [1e-10, 1.0] before
    # taking the log (the upper bound is implied: log_sigmoid <= 0).
    min_log_prob = float(np.log(1e-10))
    log_probs = tf.maximum(tf.log_sigmoid(signed_q), min_log_prob)

    log_probs_dims = tf.reshape(log_probs, [-1] + dims)
    # The reshaped tensor has rank len(dims) + 1, so -1 is axis len(dims).
    return tf.reduce_sum(log_probs_dims, axis=-1)
# Comment list
# Article table of contents