rnn_cell.py source code

python

Project: bi-att-flow  Author: allenai
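def __call__(self, inputs, state, scope=None):
    """
    Run the wrapped cell on a masked reduction of the previous states.

    :param inputs: [N*B, I + B], the cell input (I columns) concatenated with a mask row (B columns)
    :param state: [N*B, d], previous states, where d is self.state_size
    :param scope: optional variable scope name; defaults to the class name
    :return: [N*B, d]
    """
    with tf.variable_scope(scope or self.__class__.__name__):
        d = self.state_size
        # Split the concatenated input back into the real input and the mask.
        x = tf.slice(inputs, [0, 0], [-1, self._input_size])  # [N*B, I]
        mask = tf.slice(inputs, [0, self._input_size], [-1, -1])  # [N*B, B]
        B = tf.shape(mask)[1]
        prev_state = tf.expand_dims(tf.reshape(state, [-1, B, d]), 1)  # [N, B, d] -> [N, 1, B, d]
        mask = tf.tile(tf.expand_dims(tf.reshape(mask, [-1, B, B]), -1), [1, 1, 1, d])  # [N, B, B, d]
        # prev_state = self._reduce_func(tf.tile(prev_state, [1, B, 1, 1]), 2)
        # exp_mask is a helper defined elsewhere in the project; prev_state broadcasts to [N, B, B, d].
        prev_state = self._reduce_func(exp_mask(prev_state, mask), 2)  # [N, B, d]
        prev_state = tf.reshape(prev_state, [-1, d])  # [N*B, d]
        return self._cell(x, prev_state)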
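
The masked reduction in the method above can be sketched in plain Python for a single example (N = 1). This is only an illustration under stated assumptions: the snippet does not show `exp_mask` or `self._reduce_func`, so the sketch assumes that `exp_mask` adds a very large negative number wherever the mask is 0 and that the reduction is a maximum. The names `VERY_NEGATIVE` and `masked_prev_states` are hypothetical and do not come from the project.

```python
# Assumption: exp_mask pushes masked-out entries toward -infinity so that
# a max reduction ignores them. The constant below is a hypothetical stand-in.
VERY_NEGATIVE = -1e30


def exp_mask(value, mask):
    """Return value unchanged where mask is 1, a huge negative number where it is 0."""
    return value + (1.0 - mask) * VERY_NEGATIVE


def masked_prev_states(states, mask, reduce_func=max):
    """Reduce the previous states per row of the mask.

    states: B lists of length d (the [B, d] previous states of one example)
    mask:   B x B matrix of 0/1; mask[i][j] == 1 means row i may see state j
    returns B lists of length d, mirroring the [N, B, d] tensor for N = 1
    """
    B = len(states)
    d = len(states[0])
    return [
        [reduce_func(exp_mask(states[j][k], mask[i][j]) for j in range(B))
         for k in range(d)]
        for i in range(B)
    ]


# Three states of size 2; each row of the mask selects two of them.
states = [[1.0, 5.0], [2.0, 4.0], [3.0, 0.0]]
mask = [[1, 1, 0],
        [0, 1, 1],
        [1, 0, 1]]
result = masked_prev_states(states, mask)
# result == [[2.0, 5.0], [3.0, 4.0], [3.0, 5.0]]
```

Row i of the result is computed only from the states that row i of the mask selects, which corresponds to reducing over axis 2 of the [N, B, B, d] tensor in the TensorFlow code. With a different reduction (for example a sum), a multiplicative mask would be needed instead of the additive one assumed here.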