def forward(self, input):
    """Apply ``self.f`` to *input*, optionally with an auxiliary identity loss.

    When ``self.aux_loss`` is falsy, simply returns ``self.f(input)``.
    Otherwise *input* is assumed to be a batch of 2x3 affine parameters
    (shape ``(B, 2, 3)`` — TODO confirm against the caller) and the method
    additionally returns the per-sample squared Frobenius distance between
    *input* and the identity affine transform, as a ``(B, 1)`` tensor.

    Returns:
        ``self.f(input)`` when ``self.aux_loss`` is falsy, else the tuple
        ``(self.f(input), loss)`` with ``loss`` of shape ``(B, 1)``.
    """
    if not self.aux_loss:
        return self.f(input)
    # 2x3 identity affine matrix, created directly on the input's device
    # (replaces the old CPU build + is_cuda/Variable shuffle).
    identity = torch.tensor(
        [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        dtype=torch.float32,
        device=input.device,
    )
    # Broadcast across the batch without a Python loop; expand() creates
    # a view, so no per-sample copies are made.
    batch_identity = identity.unsqueeze(0).expand(input.size(0), -1, -1)
    # Squared deviation from the identity, summed over the matrix dims.
    # The original summed dim 1 then dim 2, which only worked on ancient
    # PyTorch where keepdim defaulted to True; on modern PyTorch the
    # second sum would raise a dimension error.
    diff = input - batch_identity
    loss = (diff * diff).sum(dim=(1, 2))
    return self.f(input), loss.view(-1, 1)
# NOTE(review): the two lines below ("评论列表" = comment list, "文章目录" =
# article table of contents) are blog-page navigation text accidentally pasted
# in with the code; commented out so the file parses.
# 评论列表
# 文章目录