def user_representation(self, item_sequences):
    """
    Encode item sequences into step-by-step user representations.

    The embedded sequence is run through a stack of left-padded (causal)
    dilated convolutions, so the representation at each step depends only
    on items that precede it.

    Parameters
    ----------
    item_sequences
        Batch of item-id sequences passed to ``self.item_embeddings``.
        Assumed to be a (batch, sequence_length) tensor of item ids;
        confirm against callers.

    Returns
    -------
    tuple (all_representations, final_representation)
        The first element contains all representations from step
        -1 (no items seen) to t - 1 (all but the last items seen).
        The second element contains the final representation
        at step t (all items seen). This final state can be used
        for prediction or evaluation.
    """
    def dilated_extent(width, rate):
        # Input positions covered by one dilated kernel: `width` taps
        # separated by (rate - 1) skipped positions between neighbours.
        return width + (width - 1) * (rate - 1)

    # (batch, seq, dim) -> (batch, dim, seq, 1): the embedding axis becomes
    # the convolution channel axis, plus a trailing singleton width axis.
    embedded = self.item_embeddings(item_sequences).permute(0, 2, 1)
    embedded = embedded.unsqueeze(3)

    # First layer: left-pad by the whole kernel extent so the output is
    # shifted one step right. Assuming stride 1 and no padding inside the
    # conv, this yields seq + 1 outputs where output i sees items 0..i-1
    # (output 0 sees only padding, i.e. "no items seen").
    first_pad = dilated_extent(self.kernel_width[0], self.dilation[0])
    hidden = self.nonlinearity(
        self.cnn_layers[0](F.pad(embedded, (0, 0, first_pad, 0))))
    if self.residual_connections:
        # Shift the raw embeddings by one step to align them with the
        # seq + 1 first-layer outputs.
        hidden = hidden + F.pad(embedded, (0, 0, 1, 0))

    remaining = zip(self.cnn_layers[1:],
                    self.kernel_width[1:],
                    self.dilation[1:])
    for layer, width, rate in remaining:
        # Pad by extent - 1: length is preserved and each output still
        # only sees its own and earlier positions.
        padded = F.pad(hidden, (0, 0, dilated_extent(width, rate) - 1, 0))
        activated = self.nonlinearity(layer(padded))
        if self.residual_connections:
            hidden = activated + hidden
        else:
            hidden = activated

    # Drop the singleton width axis -> (batch, dim, seq + 1).
    hidden = hidden.squeeze(3)
    return hidden[:, :, :-1], hidden[:, :, -1]
# NOTE(review): web-page residue, not code - 评论列表 ("comment list"),
# 文章目录 ("table of contents").