fft_tree_constrained_inference.py source code

python

Project: wip-constrained-extractor  Author: brain-research
def compute_down_msg(self, inc_node_msg, node_to_span_off_belief_idx,
                       node_to_span_on_start_belief_idx,
                       node_to_span_on_end_belief_idx,
                       parent_on_down_to_sum_tree_idx,
                       parent_off_down_to_sum_tree_idx):
    """Compute downward BP messages for this layer of the tree.

    Args:
      inc_node_msg: incoming messages from parent variables.
      node_to_span_off_belief_idx: map from node marginals at this layer to
        corresponding span-off marginals.
      node_to_span_on_start_belief_idx: map marking start of each span marginal.
      node_to_span_on_end_belief_idx: map marking end of each span marginal.
      parent_on_down_to_sum_tree_idx: map from marginal of parent-on
        variable down to child variable.
      parent_off_down_to_sum_tree_idx: map from marginal of parent-off
        variable down to child variable.

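    Returns:
      span_off_marginals: normalized marginal that each span is off, i.e. the
        span-off belief divided by (span-on belief + span-off belief).
      out_msg: messages passed down to the child variables: the sum of the
        incoming messages gathered through the parent-on and parent-off maps.
    """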

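    # Relies on module-level names not shown in this snippet: `tf`
    # (TensorFlow), `su` (a project utility module providing safe_divide)
    # and `padded_gather_nd`.

    # Node marginals: product of the upward message stored on this layer and
    # the incoming message from the parent variables.
    node_marginals = self.up_node_msg * inc_node_msg

    # Belief that each span is off, gathered from the node marginals.
    span_off_beliefs = padded_gather_nd(node_marginals,
                                        node_to_span_off_belief_idx, 3, 4)

    # Inclusive cumulative sums along axis 2, so the total "on" belief of a
    # span is the difference of two cumulative values.
    cumulative_node_beliefs = tf.cumsum(node_marginals, 2)

    span_on_start_cumulative_belief = padded_gather_nd(
        cumulative_node_beliefs, node_to_span_on_start_belief_idx, 3, 4)
    span_on_end_cumulative_belief = padded_gather_nd(
        cumulative_node_beliefs, node_to_span_on_end_belief_idx, 3, 4)

    span_on_beliefs = (
        span_on_end_cumulative_belief - span_on_start_cumulative_belief)

    # Normalize span-off beliefs against the total (on + off) belief;
    # su.safe_divide presumably guards against a zero normalizer.
    span_belief_normalizer = span_on_beliefs + span_off_beliefs

    span_off_marginals = su.safe_divide(span_off_beliefs,
                                        span_belief_normalizer)

    # Downward message to the children: sum of the incoming messages gathered
    # through the parent-on and parent-off maps.
    out_msg = padded_gather_nd(inc_node_msg, parent_on_down_to_sum_tree_idx, 3,
                               4)

    out_msg += padded_gather_nd(inc_node_msg, parent_off_down_to_sum_tree_idx,
                                3, 4)

    return span_off_marginals, out_msg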
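The span-on computation in compute_down_msg relies on a prefix-sum trick: with inclusive cumulative sums c[k] = x[0] + ... + x[k], the sum over positions start+1 through end is c[end] - c[start], so each span needs two lookups instead of a loop. The minimal NumPy sketch below illustrates that step and the normalization under stated assumptions: a 1-D array of node marginals instead of the original multi-dimensional tensors summed along axis 2, explicit integer start/end arrays in place of the `padded_gather_nd` index maps, and a hypothetical `safe_divide` that returns 0 where the denominator is 0 (the actual behavior of `su.safe_divide` is not shown). It is an illustration, not the project's implementation.

```python
import numpy as np


def safe_divide(num, den):
    """Hypothetical stand-in for `su.safe_divide`: num / den, or 0 where den == 0."""
    num = np.asarray(num, dtype=float)
    den = np.asarray(den, dtype=float)
    nonzero = den != 0
    # Replace zero denominators by 1 before dividing (avoids warnings),
    # then force those entries to 0.
    return np.where(nonzero, num / np.where(nonzero, den, 1.0), 0.0)


def span_off_marginals(node_marginals, span_off_beliefs, starts, ends):
    """Normalized probability that each span is off (sketch, 1-D layout).

    node_marginals:   per-position "on" beliefs, shape [n].
    span_off_beliefs: one "off" belief per span, shape [m].
    starts, ends:     int arrays, shape [m]; span j covers positions
                      starts[j] + 1 .. ends[j] (hypothetical convention).
    """
    cumulative = np.cumsum(node_marginals)           # inclusive prefix sums
    span_on = cumulative[ends] - cumulative[starts]  # sum over (start, end]
    normalizer = span_on + span_off_beliefs
    return safe_divide(span_off_beliefs, normalizer)


# Usage: three spans over four positions.
node_marginals = np.array([0.5, 1.0, 2.0, 0.5])
span_off = np.array([1.0, 2.5, 0.0])
starts = np.array([0, 1, 2])
ends = np.array([2, 3, 2])
marginals = span_off_marginals(node_marginals, span_off, starts, ends)
```

Here span 0 covers positions 1 and 2 (3.0 on, 1.0 off, giving 0.25), span 1 covers positions 2 and 3 (2.5 on, 2.5 off, giving 0.5), and span 2 is empty with zero off belief, so the guarded division gives 0. Under this convention a span beginning at position 0 would need a leading zero prepended to the prefix sums, which the sketch omits.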