def user_representation(self, item_sequences):
    """
    Compute user representation from a given sequence.

    The representation at each step is a running average of the item
    embeddings seen up to that step, computed independently for every
    embedding dimension.

    Parameters
    ----------

    item_sequences: tensor
        Item ids passed to ``self.item_embeddings``. Presumably a
        (batch, sequence_length) integer tensor: the embedding output
        must be 3-dimensional for ``permute(0, 2, 1)`` to succeed
        (TODO confirm against callers).

    Returns
    -------

    tuple (all_representations, final_representation)
        The first element contains all representations from step
        -1 (no items seen) to t - 1 (all but the last items seen).
        The second element contains the final representation
        at step t (all items seen). This final state can be used
        for prediction or evaluation.
        The embedding dimension comes before the step dimension in
        the first element; the second element has no step dimension.

    Notes
    -----

    Padding is detected per embedding component: any component that is
    exactly ``0.0`` is treated as padding and not counted. The running
    sum is divided by ``count + 1`` rather than ``count``, which keeps
    the division well defined before any item has been seen (count 0).
    """
    # Make the embedding dimension the channel dimension
    sequence_embeddings = (self.item_embeddings(item_sequences)
                           .permute(0, 2, 1))

    # Add a trailing dimension of 1, so that the 4-element pad spec
    # below addresses the step axis (second to last) and leaves the
    # trailing axis alone.
    sequence_embeddings = sequence_embeddings.unsqueeze(3)

    # Pad it with one all-zero step on the left: this is the
    # "no items seen" step.
    sequence_embeddings = F.pad(sequence_embeddings,
                                (0, 0, 1, 0))

    # Average representations, ignoring padding. Both cumulative sums
    # run along the step axis and have identical shapes, so no explicit
    # expansion is needed before dividing.
    sequence_embedding_sum = torch.cumsum(sequence_embeddings, 2)
    non_padding_entries = torch.cumsum(
        (sequence_embeddings != 0.0).float(), 2
    )

    user_representations = (
        sequence_embedding_sum / (non_padding_entries + 1)
    ).squeeze(3)

    # Everything but the last step, and the last step on its own.
    return user_representations[:, :, :-1], user_representations[:, :, -1]
评论列表
文章目录