def forward(self, x):
    """Run the stem, the three memory-bank blocks, and the classifier head.

    For each block, a list of memory banks is filled from the incoming
    feature map, then every module in ``self.mod`` (consumed in order) reads
    a concatenation of banks and writes slices of its output back into
    banks, summing into banks that already hold a tensor. The banks are
    concatenated along dim 1 and passed through ``self.trans1`` /
    ``self.trans2`` after the first two blocks (the last block has no
    transition).

    Args:
        x: Input passed to ``self.conv1``. Presumably a 4-D
            ``(batch, channels, height, width)`` tensor, since the head
            uses ``F.avg_pool2d`` -- TODO confirm against the caller.

    Returns:
        Log-probabilities from ``F.log_softmax`` over dim 1 of
        ``self.fc``'s output; the batch axis is always kept, including
        for a batch of one.
    """
    # Stem convolution
    out = self.conv1(x)

    # Allocate memory banks: one list of empty slots per block.
    m = [[None for _ in range(d)] for d in self.D]

    module_index = 0
    transitions = [self.trans1, self.trans2, None]
    for i, (incoming_channels, outgoing_channels, g_values, bs, trans) in enumerate(
            zip(self.incoming, self.outgoing, self.G, self.bank_sizes, transitions)):
        # Number of channels held by a single bank in this block.
        bank_width = bs * self.N

        # Write to initial memory banks
        for j in range(out.size(1) // bank_width):
            m[i][j] = out[:, j * bank_width:(j + 1) * bank_width]

        for read, write, g in zip(incoming_channels, outgoing_channels, g_values):
            # Cat read tensors
            inp = torch.cat([m[i][index] for index in read], 1)

            # Apply module and increment op index
            out = self.mod[module_index](inp)
            module_index += 1

            # The output is treated as g // bs bank-sized slices; write
            # targets beyond that count wrap around and reuse slices.
            slices_per_output = g // bs
            for j, w in enumerate(write):
                start = (j % slices_per_output) * bank_width
                piece = out[:, start:start + bank_width]
                # Allocate the bank if it is empty, otherwise add to it.
                m[i][w] = piece if m[i][w] is None else m[i][w] + piece

        # NOTE: torch.cat needs every bank of this block to hold a tensor
        # by now; a bank left as None will raise here.
        out = torch.cat(m[i], 1)
        if trans is not None:
            out = trans(out)

    # Pool with a kernel equal to the feature-map height.
    pooled = F.avg_pool2d(F.relu(self.bn1(out)), out.size(2))
    # Flatten everything except the batch axis. A bare torch.squeeze would
    # also drop the batch axis when the batch size is 1.
    flat = pooled.view(pooled.size(0), -1)
    # Explicit dim: the implicit choice is deprecated and rank-dependent.
    return F.log_softmax(self.fc(flat), dim=1)
# Web-page navigation residue, not part of the code:
# "评论列表" (comment list) / "文章目录" (table of contents)