model.py source code

Project: srnn-pytorch  Author: vvanirudh
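import numpy as np
import torch


def forward(self, h_temporal, h_spatials):
    '''
    Forward pass: attend over the spatial edgeRNN hidden states, using the
    temporal edgeRNN hidden state as the query.
    params:
    h_temporal : Hidden state of the temporal edgeRNN (a single row,
                 squeezed to 1-D after embedding)
    h_spatials : Hidden states of all spatial edgeRNNs connected to the node,
                 shape (num_edges, hidden_size)
    returns:
    weighted_value : Attention-weighted sum of h_spatials, shape (hidden_size,)
    attn : Attention weights over the spatial edges, shape (num_edges,)
    '''
    # Number of spatial edges
    num_edges = h_spatials.size(0)

    # Embed the temporal edgeRNN hidden state and drop the leading
    # singleton dimension: -> (attention_size,)
    temporal_embed = self.temporal_edge_layer(h_temporal)
    temporal_embed = temporal_embed.squeeze(0)

    # Embed the spatial edgeRNN hidden states: -> (num_edges, attention_size)
    spatial_embed = self.spatial_edge_layer(h_spatials)

    # Dot-product attention scores, one per spatial edge: -> (num_edges,)
    attn = torch.mv(spatial_embed, temporal_embed)

    # Scale the scores by num_edges / sqrt(attention_size); the factor grows
    # with the variable number of spatial edges
    temperature = num_edges / np.sqrt(self.attention_size)
    attn = torch.mul(attn, temperature)

    # Softmax over the edges; pass dim explicitly (implicit dim is deprecated)
    attn = torch.nn.functional.softmax(attn, dim=0)

    # Attention-weighted sum of the raw spatial hidden states: -> (hidden_size,)
    weighted_value = torch.mv(torch.t(h_spatials), attn)

    return weighted_value, attn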
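To try the forward pass outside the full model, here is a minimal standalone sketch. It assumes `temporal_edge_layer` and `spatial_edge_layer` are plain `nn.Linear(hidden_size, attention_size)` projections and that the temporal and spatial hidden states share one `hidden_size`; the snippet shows neither the layer definitions nor the sizes. The class name `EdgeAttentionSketch` and the sizes 128/64/5 are hypothetical. It uses `math.sqrt` in place of `np.sqrt`, which yields the same scalar.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class EdgeAttentionSketch(nn.Module):
    """Hypothetical module mirroring the forward pass above.

    Assumption: both embedding layers are plain nn.Linear projections.
    """

    def __init__(self, hidden_size: int, attention_size: int):
        super().__init__()
        self.attention_size = attention_size
        self.temporal_edge_layer = nn.Linear(hidden_size, attention_size)
        self.spatial_edge_layer = nn.Linear(hidden_size, attention_size)

    def forward(self, h_temporal, h_spatials):
        # h_temporal: (1, hidden_size); h_spatials: (num_edges, hidden_size)
        num_edges = h_spatials.size(0)
        query = self.temporal_edge_layer(h_temporal).squeeze(0)  # (attention_size,)
        keys = self.spatial_edge_layer(h_spatials)               # (num_edges, attention_size)
        scores = torch.mv(keys, query)                           # (num_edges,)
        # Same scaling as the snippet: multiply by num_edges / sqrt(attention_size)
        scores = scores * (num_edges / math.sqrt(self.attention_size))
        attn = F.softmax(scores, dim=0)                          # non-negative, sums to 1
        # Weighted sum of the raw (unprojected) spatial hidden states
        weighted_value = torch.mv(h_spatials.t(), attn)          # (hidden_size,)
        return weighted_value, attn


torch.manual_seed(0)
layer = EdgeAttentionSketch(hidden_size=128, attention_size=64)
h_temporal = torch.randn(1, 128)
h_spatials = torch.randn(5, 128)
weighted_value, attn = layer(h_temporal, h_spatials)
print(weighted_value.shape, attn.shape)  # torch.Size([128]) torch.Size([5])
```

Unlike standard scaled dot-product attention, which divides the scores by sqrt(attention_size), this code multiplies them by num_edges / sqrt(attention_size), so the effective inverse temperature of the softmax grows linearly with the number of spatial edges. The returned vector is a weighted sum of the raw spatial hidden states, not of their embeddings.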