def _batch_log_det(self):
rank = array_ops.size(self._shape_arg)
last_dim = math_ops.cast(
array_ops.gather(self._shape_arg, rank - 1), dtype=self.dtype)
log_det = (last_dim * math_ops.log(math_ops.abs(self._scale)) *
array_ops.ones(self.batch_shape(), dtype=self.dtype))
log_det.set_shape(self.get_batch_shape())
return log_det
operator_pd_identity.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录