def __init__(self, rnn_states, type_embedder, name='DelexicalizedDynamicPredicateEmbedder'):
    """Build predicate embeddings from mention-position RNN states plus type embeddings.

    Each predicate's embedding is the concatenation (along axis 1) of:
      1. the mean of the RNN states gathered at the positions where the
         predicate is mentioned, and
      2. the `embeds` of a MeanSequenceEmbedder built over
         `type_embedder.embeds`.

    The mention positions are graph inputs. `_row_indices` is an int32
    placeholder. `_col_indices` is a FeedSequenceBatch, presumably fed at
    run time as well; confirm against FeedSequenceBatch.

    Args:
        rnn_states (SequenceBatch): of shape (num_contexts, seq_length, rnn_state_dim)
        type_embedder (TokenEmbedder)
        name (str): name scope under which this embedder's ops are created.
    """
    self._type_embedder = type_embedder
    with tf.name_scope(name):
        # Column indices into rnn_states (they index time), one row of
        # mention positions per predicate:
        # (num_predicates, max_predicate_mentions)
        col_indices = FeedSequenceBatch()
        self._col_indices = col_indices

        # Row indices into rnn_states (they index the utterance), one per
        # predicate: (num_predicates,)
        row_indices = tf.placeholder(shape=[None], dtype=tf.int32)
        self._row_indices = row_indices

        # Reshape the row indices so they pair up element-wise with the 2-D
        # column indices (presumably by adding broadcastable trailing dims;
        # see expand_dims_for_broadcast).
        row_idx_broadcast = expand_dims_for_broadcast(row_indices, col_indices.values)

        # RNN state at every (row, col) mention slot:
        # (num_predicates, max_predicate_mentions, rnn_state_dim).
        # The column-index mask is reused to mark which mention slots are real.
        gathered_states = gather_2d(rnn_states.values, row_idx_broadcast, col_indices.values)
        mention_states = SequenceBatch(gathered_states, col_indices.mask)

        # Average over mentions -> (num_predicates, rnn_state_dim).
        # allow_empty=True presumably permits predicates with no mentions.
        # The finiteness check makes any NaN/Inf fail loudly at run time.
        state_embeds = tf.verify_tensor_all_finite(
            reduce_mean(mention_states, allow_empty=True),
            "RNN-state-based embeddings")

        self._type_seq_embedder = MeanSequenceEmbedder(type_embedder.embeds, name='TypeEmbedder')

        # Final embedding: RNN-state part followed by the type part, joined on axis 1.
        # NOTE(review): `tf.concat(axis, values)` is the pre-TF-1.0 argument
        # order; TF >= 1.0 expects `tf.concat(values, axis)`. Confirm the
        # TensorFlow version this project targets before changing it.
        embed_parts = [state_embeds, self._type_seq_embedder.embeds]
        self._embeds = tf.concat(1, embed_parts)
评论列表
文章目录