attention_rnn_cell.py source code

python

Project: neural-chat    Author: henriblancke
def __init__(self, cell, attention_states, batch_size, embedding, initializer=None, num_heads=1, scope=None):
        if not isinstance(cell, tf.nn.rnn_cell.RNNCell):
            raise TypeError("The parameter cell is not RNNCell.")
        self._cell = cell
        self._attention_states = attention_states
        self.embedding = embedding

        with variable_scope.variable_scope(scope or "attention_decoder"):
            # batch_size = attention_states.get_shape()[0].value
            attn_length = attention_states.get_shape()[1].value
            attn_size = attention_states.get_shape()[2].value

            # Reshape the encoder states to 4-D so W_a can be applied to every
            # state at once with the 1x1 convolution below.
            hidden = array_ops.reshape(
                attention_states, [-1, attn_length, 1, attn_size])
            hidden_features = []
            v = []
            attention_vec_size = attn_size  # Size of query vectors for attention.
            for a in xrange(num_heads):
                # Per-head parameters: k is a 1x1 conv kernel playing the role of W_a
                # applied to every encoder state; v[a] reduces tanh(...) to a scalar score.
                k = variable_scope.get_variable("AttnW_%d" % a,
                                                [1, 1, attn_size, attention_vec_size])
                hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
                v.append(variable_scope.get_variable("AttnV_%d" % a,
                                                     [attention_vec_size]))

            # Initial attention reads: one zero vector of shape [batch_size, attn_size] per head.
            batch_attn_size = array_ops.pack([batch_size, attn_size])
            self.attns = [array_ops.zeros(batch_attn_size, dtype=dtypes.float32)
                          for _ in xrange(num_heads)]

            def attention(query):
                """Put attention masks on hidden using hidden_features and query."""
                ds = []  # Results of attention reads will be stored here.
                for a in xrange(num_heads):
                    with variable_scope.variable_scope("Attention_%d" % a):
                        y = linear(query, attention_vec_size, True)
                        y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
                        # Attention mask is a softmax of v^T * tanh(...).
                        s = math_ops.reduce_sum(
                            v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3])
                        attn_weights = nn_ops.softmax(s)  # attention distribution over encoder states
                        # Now calculate the attention-weighted vector d.
                        d = math_ops.reduce_sum(
                            array_ops.reshape(attn_weights, [-1, attn_length, 1, 1]) * hidden,
                            [1, 2])
                        ds.append(array_ops.reshape(d, [-1, attn_size]))
                return ds

            self.attention = attention
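
For readers skimming the listing: one attention head above computes score_i = v_a^T tanh(W_a h_i + U q) for every encoder state h_i, turns the scores into a softmax distribution, and returns the weighted sum of the states. The NumPy sketch below is illustrative only and not part of the project; W, U, v and the sizes are hypothetical stand-ins. In the code above, W_a h_i is precomputed for all states by the 1x1 convolution (hidden_features), U q comes from linear(query, attention_vec_size, True), and v_a is the AttnV_%d variable.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

attn_length, attn_size = 5, 8                  # hypothetical sizes
h = np.random.randn(attn_length, attn_size)    # encoder states ("hidden")
q = np.random.randn(attn_size)                 # decoder query
W = np.random.randn(attn_size, attn_size)      # stands in for AttnW_a (the 1x1 conv)
U = np.random.randn(attn_size, attn_size)      # stands in for linear(query, ...)
v = np.random.randn(attn_size)                 # stands in for AttnV_a

scores = np.array([v.dot(np.tanh(W.dot(h_i) + U.dot(q))) for h_i in h])  # s in the code
weights = softmax(scores)                                                # softmax(s)
d = (weights[:, None] * h).sum(axis=0)                                   # attention read d
print(d.shape)  # (attn_size,)

Note that this excerpt omits the module header: array_ops, math_ops, nn_ops, variable_scope, dtypes and the linear helper are presumably imported at the top of attention_rnn_cell.py from the TensorFlow (0.x-era) Python ops modules and its seq2seq utilities.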