def getSocialTensor(self, grid, hidden_states):
    '''
    Computes the social tensor for a given grid mask and hidden states of all peds.

    For every ped, the hidden states are pooled into grid cells by
    multiplying the transpose of that ped's grid mask with hidden_states.
    The result is created on whatever device the inputs live on (CPU or
    GPU); it is not forced onto CUDA.

    params:
    grid : Grid masks, indexed by ped along the first dimension. Each
           grid[node] must have shape (K, grid_size*grid_size).
    hidden_states : Hidden states of all peds, shape (K, rnn_size).
                    K is presumably the number of peds -- confirm against
                    the caller.

    returns:
    Tensor of shape (numNodes, grid_size*grid_size*rnn_size), where row
    `node` is the flattened product of grid[node] transposed with
    hidden_states.
    '''
    # Number of peds
    numNodes = grid.size()[0]

    # Batched form of torch.mm(torch.t(grid[node]), hidden_states) for every
    # node: (numNodes, G, K) x (K, rnn_size) -> (numNodes, G, rnn_size).
    # hidden_states is broadcast across the ped dimension, so no Python loop
    # and no pre-allocated (device-specific) buffer is needed.
    social_tensor = torch.matmul(grid.transpose(1, 2), hidden_states)

    # Flatten the per-ped (cell, feature) matrix into one vector per ped.
    # The explicit target size also rejects inputs whose shape disagrees
    # with self.grid_size / self.rnn_size.
    flat_size = self.grid_size * self.grid_size * self.rnn_size
    return social_tensor.contiguous().view(numNodes, flat_size)
评论列表
文章目录