representations.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:spotlight 作者: maciejkula 项目源码 文件源码
def user_representation(self, item_sequences):
        """
        Encode item sequences with the stack of causal convolutions.

        Parameters
        ----------

        item_sequences: tensor
            Item ids handed to ``self.item_embeddings``. Presumably of
            shape (batch, sequence length) -- TODO confirm with callers;
            the code only requires that the embedded result is 3-D.

        Returns
        -------

        tuple (all_representations, final_representation)
            ``all_representations`` holds every position along the last
            axis except the final one: the representations from step -1
            (nothing seen yet) up to step t - 1 (everything but the last
            item seen). ``final_representation`` is the last position,
            i.e. the state at step t once all items have been seen; it
            is the one to use for prediction or evaluation.
        """

        def pad_front(tensor, amount):
            # F.pad reads its pairs from the last dimension backwards:
            # (0, 0) leaves the trailing singleton dimension untouched,
            # (amount, 0) prepends zeros along the sequence dimension so
            # that no output position can see later items.
            return F.pad(tensor, (0, 0, amount, 0))

        def span(width, rate):
            # Number of input positions covered by one dilated kernel.
            return width + (width - 1) * (rate - 1)

        # Channels-first layout with a dummy trailing axis, so that the
        # embedding dimension acts as the convolution channel dimension.
        embedded = self.item_embeddings(item_sequences)
        embedded = embedded.permute(0, 2, 1).unsqueeze(3)

        # First layer: pad by the full receptive field, which makes the
        # output one step longer than the input (the extra leading step
        # is the "no items seen" representation).
        first_span = span(self.kernel_width[0], self.dilation[0])
        padded = pad_front(embedded, first_span)
        hidden = self.nonlinearity(self.cnn_layers[0](padded))

        if self.residual_connections:
            # Shift the embeddings by one step to line up with the
            # lengthened output before adding them back.
            hidden = hidden + pad_front(embedded, 1)

        remaining_layers = zip(self.cnn_layers[1:],
                               self.kernel_width[1:],
                               self.dilation[1:])

        for layer, width, rate in remaining_layers:
            # Later layers pad by one less than the receptive field,
            # which keeps the sequence length unchanged.
            extra = span(width, rate) - 1
            shortcut = hidden
            padded = pad_front(hidden, extra)
            hidden = self.nonlinearity(layer(padded))

            if self.residual_connections:
                hidden = hidden + shortcut

        # Drop the dummy trailing axis again.
        hidden = hidden.squeeze(3)

        return hidden[:, :, :-1], hidden[:, :, -1]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号