def _batch_sqrt_log_abs_det(self):
rank = array_ops.size(self._shape_arg)
last_dim = math_ops.cast(
array_ops.gather(self._shape_arg, rank - 1), dtype=self.dtype)
sqrt_log_abs_det = 0.5 * last_dim * math_ops.log(
math_ops.abs(self._scale)) * array_ops.ones(
self.batch_shape(), dtype=self.dtype)
sqrt_log_abs_det.set_shape(self.get_batch_shape())
return sqrt_log_abs_det
operator_pd_identity.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录