parse_model.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:lang2program 作者: kelvinguu 项目源码 文件源码
def __init__(self, rnn_states, type_embedder, name='DelexicalizedDynamicPredicateEmbedder'):
        """Build the graph ops that embed predicates from RNN states and types.

        Each predicate embedding is the concatenation of two parts:
        the mean of the RNN states gathered at the (utterance, time) positions
        fed for that predicate, and the embedding produced by a
        MeanSequenceEmbedder built over ``type_embedder.embeds``.

        Args:
            rnn_states (SequenceBatch): of shape
                (num_contexts, seq_length, rnn_state_dim)
            type_embedder (TokenEmbedder): supplies ``embeds`` for the type part
            name (str): name scope under which all ops are created
        """
        self._type_embedder = type_embedder

        with tf.name_scope(name):
            # Time-step (column) positions into rnn_states, fed at run time.
            # (num_predicates, max_predicate_mentions)
            mention_cols = FeedSequenceBatch()
            self._col_indices = mention_cols

            # Utterance (row) positions into rnn_states, one per predicate.
            # (num_predicates,)
            utterance_rows = tf.placeholder(dtype=tf.int32, shape=[None])
            self._row_indices = utterance_rows

            # Expand the per-predicate row index so it lines up with the
            # column indices when both are passed to gather_2d.
            rows_broadcast = expand_dims_for_broadcast(utterance_rows, mention_cols.values)

            # Pick out one RNN state per (row, column) pair; the mask of the
            # column batch marks which mention slots are real.
            # (num_predicates, max_predicate_mentions, rnn_state_dim)
            gathered_states = gather_2d(rnn_states.values, rows_broadcast, mention_cols.values)
            mention_states = SequenceBatch(gathered_states, mention_cols.mask)

            # Average over mentions; allow_empty=True is passed so predicates
            # with no mentions are accepted by reduce_mean.
            # (num_predicates, rnn_state_dim)
            mean_states = reduce_mean(mention_states, allow_empty=True)
            checked_states = tf.verify_tensor_all_finite(mean_states, "RNN-state-based embeddings")

            # Type-based half of the embedding.
            type_seq_embedder = MeanSequenceEmbedder(type_embedder.embeds, name='TypeEmbedder')
            self._type_seq_embedder = type_seq_embedder

            # NOTE(review): axis-first argument order matches the pre-1.0
            # TensorFlow tf.concat signature -- confirm the pinned TF version.
            embedding_parts = [checked_states, type_seq_embedder.embeds]
            self._embeds = tf.concat(1, embedding_parts)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号