def forward(self, X):
    """Encode input features against the layer's codewords.

    Args:
        X: Either a tuple/list of inputs, which is handed unchanged to
            ``my_data_parallel``, or a single ``Variable`` that is 3-D
            (B x D x N) or 4-D (B x D x H x W) and whose second dimension
            equals ``self.D``.

    Returns:
        For a single input, the result of
        ``aggregate(A, X, self.codewords)``; for a tuple/list, whatever
        ``my_data_parallel`` returns.

    Raises:
        RuntimeError: If ``X`` is not a tuple, list or ``Variable``, or if
            it is neither 3-D nor 4-D.
        AssertionError: If ``X.size(1)`` differs from ``self.D``.
    """
    if isinstance(X, (tuple, list)):
        # for self-parallel mode, please see encoding.nn
        return my_data_parallel(self, X)
    if not isinstance(X, Variable):
        raise RuntimeError('unknown input type')

    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`; the exception type is kept for existing callers.
    if X.size(1) != self.D:
        raise AssertionError(
            'expected X.size(1) == %s, got %s' % (self.D, X.size(1))
        )

    if X.dim() == 3:
        # B x D x N  ->  B x N x D
        X = X.transpose(1, 2).contiguous()
    elif X.dim() == 4:
        # B x D x H x W  ->  B x (H*W) x D
        X = X.view(X.size(0), self.D, -1).transpose(1, 2).contiguous()
    else:
        raise RuntimeError('Encoding Layer unknown input dims!')

    # Assignment weights: softmax of the scaled L2 output.
    # NOTE(review): the softmax runs over dim=1. If scaledL2 returns
    # B x N x K, that normalizes across the N descriptors rather than the
    # K codewords -- confirm the intended axis against scaledL2.
    A = F.softmax(scaledL2(X, self.codewords, self.scale), dim=1)
    # aggregate
    return aggregate(A, X, self.codewords)
# Comment list
# Table of contents