dream.py source code

python

Project: DREAM  Author: LaceyChen17
def forward(self, x, lengths, hidden):
        # Basket encoding
        # x: nested lists indexed as (user in batch, time step, product indices)
        ub_seqs = []  # one basket-sequence tensor per user
        for user in x:
            embed_baskets = []
            for basket in user:
                basket = torch.LongTensor(basket).unsqueeze(0)  # shape: 1, len(basket)
                basket = basket.cuda() if self.config.cuda else basket  # move to the GPU if configured
                basket = self.encode(basket)  # shape: 1, len(basket), embedding_dim
                # self.pool must return shape 1, 1, embedding_dim for the concatenation below to work
                embed_baskets.append(self.pool(basket, dim=1))
            # concatenate all of the current user's baskets along the time axis
            ub_seqs.append(torch.cat(embed_baskets, 1))  # shape: 1, num_basket, embedding_dim

        # Input for the RNN: every user must already be padded to the same number of baskets
        ub_seqs = torch.cat(ub_seqs, 0)  # shape: batch_size, max_len, embedding_dim
        # lengths holds each user's true basket count, sorted in decreasing order by default
        packed_ub_seqs = torch.nn.utils.rnn.pack_padded_sequence(ub_seqs, lengths, batch_first=True)

        # RNN
        output, h_u = self.rnn(packed_ub_seqs, hidden)
        # unpack the packed output back into a padded tensor
        dynamic_user, _ = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)  # shape: batch_size, max_len, embedding_dim
        return dynamic_user, h_u
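
The method above concatenates all users into one tensor and packs it with `lengths`, so it appears to expect `x` to arrive already padded to a common number of baskets, with users ordered from longest to shortest sequence. The following is a minimal pure-Python sketch of one way to prepare such input. The helper name `pad_and_sort_baskets`, the placeholder basket `[0]`, and the convention that index 0 is reserved for padding are all assumptions made for illustration. They are not taken from the project.

```python
def pad_and_sort_baskets(batch, pad_basket=(0,)):
    """Pad each user's basket sequence to the longest one and sort users by length.

    batch: list of users, each a list of baskets, each a list of product indices.
    pad_basket: hypothetical placeholder basket (index 0 is assumed to mean padding).
    Returns (x, lengths, order); order[i] is the original position of sorted user i.
    """
    lengths = [len(user) for user in batch]
    max_len = max(lengths)
    # Sort user positions by true sequence length, longest first (stable for ties).
    order = sorted(range(len(batch)), key=lambda i: -lengths[i])
    x = [
        [list(b) for b in batch[i]]
        + [list(pad_basket) for _ in range(max_len - lengths[i])]
        for i in order
    ]
    return x, [lengths[i] for i in order], order


batch = [
    [[3, 7], [2]],             # user 0: two baskets
    [[5], [1, 4], [9, 8, 6]],  # user 1: three baskets
    [[2, 2]],                  # user 2: one basket
]
x, lengths, order = pad_and_sort_baskets(batch)
print(lengths)  # [3, 2, 1]
print(order)    # [1, 0, 2]
print(x[2])     # [[2, 2], [0], [0]]
```

The returned `order` can be used to map the rows of `dynamic_user` back to the original user positions after the forward pass.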