models.py source code

python

Project: sdp  Author: tansey
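```python
import numpy as np
import tensorflow as tf

# Method of the model class in models.py (the class body is not shown here).
def get_log_probs(self, indices, splits, dims):
    '''Get the necessary nodes from the tree, calculate the log probs, and reshape appropriately.'''
    dim1size = int(np.prod(dims))
    # tf.gather(self._W, indices) is [batchsize, dim1size, inputlayersize];
    # the transpose makes sampled_W [batchsize, inputlayersize, dim1size]
    sampled_W = tf.transpose(tf.gather(self._W, indices), [0, 2, 1])
    sampled_b = tf.gather(self._b, indices)  # [batchsize, dim1size]
    # input_layer is [batchsize, inputlayersize]; expanding it to [batchsize, 1, inputlayersize]
    # lets a batched matmul compute q = X*W + b with shape [batchsize, dim1size]
    sampled_q = tf.reshape(tf.matmul(tf.expand_dims(self._input_layer, 1), sampled_W),
                           [-1, dim1size]) + sampled_b
    sampled_probs = tf.sigmoid(sampled_q)  # same as 1 / (1 + exp(-q))
    # Use p where the split flag is positive and 1 - p otherwise; clip to avoid log(0)
    log_probs = tf.math.log(tf.clip_by_value(
        tf.where(splits > 0, sampled_probs, 1 - sampled_probs), 1e-10, 1.0))
    # list(dims) keeps this a list concatenation even if dims is a tuple or NumPy array
    log_probs_dims = tf.reshape(log_probs, [-1] + list(dims))
    # Sum over the last axis, leaving shape [batchsize] + dims[:-1]
    return tf.reduce_sum(log_probs_dims, axis=[len(dims)])
```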
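To make the shape bookkeeping above concrete without the surrounding TensorFlow class, here is a minimal NumPy sketch of the same computation. Because the snippet does not show them, it assumes that `self._W` is `[num_nodes, input_size]`, `self._b` is `[num_nodes]`, and `indices`/`splits` are `[batch, prod(dims)]`. The helper name `get_log_probs_np` and the toy three-node tree are hypothetical illustrations, not part of the sdp project. Each selected node contributes log σ(q) when its split flag is positive and log(1 − σ(q)) otherwise, and the final sum over the last axis adds those per-node terms.

```python
import numpy as np

def get_log_probs_np(W, b, x, indices, splits, dims):
    """Hypothetical NumPy sketch of get_log_probs (not from the sdp repo).

    Assumed shapes:
      W:       [num_nodes, input_size]  one weight vector per tree node
      b:       [num_nodes]              one bias per tree node
      x:       [batch, input_size]      input layer
      indices: [batch, prod(dims)]      node ids to evaluate for each example
      splits:  [batch, prod(dims)]      > 0 selects sigmoid(q), else 1 - sigmoid(q)
    """
    sampled_W = W[indices]                             # [batch, dim1size, input_size]
    sampled_b = b[indices]                             # [batch, dim1size]
    # q = x . w + b for every selected node of every example
    sampled_q = np.einsum('bi,bdi->bd', x, sampled_W) + sampled_b
    probs = 1.0 / (1.0 + np.exp(-sampled_q))           # logistic sigmoid
    chosen = np.where(splits > 0, probs, 1.0 - probs)
    log_probs = np.log(np.clip(chosen, 1e-10, 1.0))
    log_probs = log_probs.reshape([-1] + list(dims))
    return log_probs.sum(axis=len(dims))               # sum over the last axis

# Toy tree (hypothetical): node 0 is the root, nodes 1 and 2 are its children.
# Each row lists the two nodes on the path to one of the four leaves.
rng = np.random.default_rng(0)
W = rng.normal(size=(3, 5))
b = rng.normal(size=3)
x = np.repeat(rng.normal(size=(1, 5)), 4, axis=0)      # same input for every leaf
indices = np.array([[0, 1], [0, 1], [0, 2], [0, 2]])
splits = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
leaf_log_probs = get_log_probs_np(W, b, x, indices, splits, [2])
print(np.exp(leaf_log_probs).sum())                    # approximately 1.0
```

Reading the last axis as "the nodes on a path through the tree" is an assumption based on the docstring. Under that reading, the exponentiated sums in the toy example form a distribution over the four leaves, since (1 − p0)(1 − p1) + (1 − p0)p1 + p0(1 − p2) + p0p2 = 1.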