def forward(self, x, lengths, hidden):
    """Encode every user's baskets and run the sequences through ``self.rnn``.

    Each basket (a list of product indices) becomes a LongTensor of shape
    (1, len(basket)). It is passed through ``self.encode`` and reduced along
    dim 1 by ``self.pool``. One user's pooled baskets are concatenated along
    dim 1. The users are then concatenated along dim 0, packed with
    ``pack_padded_sequence`` and fed to ``self.rnn``.

    Args:
        x: Nested lists: one entry per user, each a list of baskets, each
            basket a list of product indices. NOTE(review): the final
            ``torch.cat`` along dim 0 only works if every user has the same
            number of baskets (i.e. callers pad) -- confirm.
        lengths: Valid sequence length per user, passed to
            ``pack_padded_sequence``. NOTE(review): that function requires
            decreasing order by default -- confirm callers sort the batch.
        hidden: Initial hidden state, forwarded to ``self.rnn`` unchanged.

    Returns:
        tuple: ``(dynamic_user, h_u)``. ``dynamic_user`` is the RNN output
        unpacked back to a padded batch-first tensor. ``h_u`` is the final
        hidden state returned by ``self.rnn``.
    """
    use_cuda = self.config.cuda

    # Basket encoding: one tensor per user, built from that user's pooled baskets.
    ub_seqs = []
    for user in x:
        embed_baskets = []
        for basket in user:
            # view() instead of resize_(): the fresh tensor is contiguous, so
            # this gives the same (1, len) shape without the low-level in-place
            # storage resize. An explicit length (not -1) keeps empty baskets
            # valid.
            basket = torch.LongTensor(basket).view(1, len(basket))
            if use_cuda:
                basket = basket.cuda()
            # presumably (1, len(basket), embedding_dim) if self.encode is an
            # embedding lookup -- TODO confirm
            basket = self.encode(torch.autograd.Variable(basket))
            embed_baskets.append(self.pool(basket, dim=1))
        # Concatenate this user's pooled baskets along dim 1 (the time axis).
        ub_seqs.append(torch.cat(embed_baskets, 1))

    # RNN input: users stacked along dim 0 (batch_first layout).
    ub_seqs = torch.cat(ub_seqs, 0)
    if use_cuda:
        # Normally a no-op because the inputs were already moved to the GPU;
        # kept as a safeguard so the device behaviour matches the original.
        ub_seqs = ub_seqs.cuda()
    packed_ub_seqs = torch.nn.utils.rnn.pack_padded_sequence(
        ub_seqs, lengths, batch_first=True
    )

    # RNN over the packed sequences, then unpack to a padded batch-first tensor.
    output, h_u = self.rnn(packed_ub_seqs, hidden)
    dynamic_user, _ = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
    return dynamic_user, h_u
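# 评论列表 ("comment list") and 文章目录 ("table of contents") are navigation
# labels left over from the web page this snippet was copied from.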