def forward(self, h_temporal, h_spatials):
    """Attend over the spatial edgeRNN hidden states, using the temporal one as query.

    Params:
        h_temporal: Hidden state of the temporal edgeRNN. After
            ``temporal_edge_layer`` and ``squeeze(0)`` it must be 1-D (it is
            the vector argument of ``torch.mv``), so presumably it has shape
            (1, hidden) -- confirm with the caller.
        h_spatials: Hidden states of all spatial edgeRNNs connected to the
            node, one row per edge. It must be 2-D (num_edges, hidden),
            because it is transposed and used as the matrix in ``torch.mv``.

    Returns:
        weighted_value: 1-D tensor of length ``hidden``. It is the
            attention-weighted sum of the rows of ``h_spatials``.
        attn: 1-D tensor of length ``num_edges`` holding the
            softmax-normalised attention weights.
    """
    # Number of spatial edges (rows of h_spatials).
    num_edges = h_spatials.size(0)

    # Embed the temporal edgeRNN hidden state. Drop the leading singleton
    # dim so it can serve as the query vector in torch.mv.
    temporal_embed = self.temporal_edge_layer(h_temporal).squeeze(0)
    # Embed each spatial edgeRNN hidden state (one row per edge).
    spatial_embed = self.spatial_edge_layer(h_spatials)

    # Dot-product attention logits: one score per spatial edge.
    attn = torch.mv(spatial_embed, temporal_embed)

    # Scale the logits by num_edges / sqrt(attention_size). Unlike standard
    # scaled dot-product attention, this factor also grows with the number
    # of edges ("variable length"). NOTE(review): presumably intentional, to
    # keep the distribution sharp when there are many neighbours -- confirm.
    temperature = num_edges / np.sqrt(self.attention_size)
    attn = torch.mul(attn, temperature)

    # Normalise over the edge dimension. `attn` is 1-D, so dim=0 is exactly
    # what the implicit choice selected. Passing it explicitly avoids the
    # deprecated implicit-dim path and its per-call warning.
    attn = torch.nn.functional.softmax(attn, dim=0)

    # Weighted sum of the raw (un-embedded) spatial hidden states.
    weighted_value = torch.mv(torch.t(h_spatials), attn)
    return weighted_value, attn
评论列表
文章目录