def loss(self, examples):
    """Return the captioning cross-entropy loss for a batch of examples.

    Each example is drawn into an input array via ``e.sequence.draw()``,
    encoded with ``self.encoder``, and its token sequence ``e.tokens`` is
    used as the decoder's caption target.

    Args:
        examples: Non-empty list of examples. Each must expose ``e.tokens``
            (something with a length) and ``e.sequence.draw()``. The
            caller's list is not modified.

    Returns:
        The result of ``F.cross_entropy`` between the decoder's output
        distributions and the packed caption targets. With PyTorch's default
        ``reduction='mean'``, this is the mean over all target positions.

    Raises:
        ValueError: If ``examples`` is empty.
    """
    if not examples:
        raise ValueError("loss() requires at least one example")

    # IMPORTANT: order the batch by caption length, longest first.
    # pack_padded_sequence (used below, and presumably inside the decoder)
    # expects lengths in decreasing order by default. sorted() is used
    # instead of list.sort() so the caller's list is not reordered.
    batch = sorted(examples, key=lambda e: len(e.tokens), reverse=True)

    drawings = np.array([e.sequence.draw() for e in batch], dtype=np.float32)
    x = variable(drawings)
    # Insert a singleton channel dimension at axis 1. Presumably the
    # drawings are (H, W) images, so x becomes (batch, 1, H, W) -- TODO confirm.
    x = x.unsqueeze(1)
    imageFeatures = self.encoder(x)

    # NOTE(review): assumes buildCaptions returns the decoder inputs, the
    # per-example lengths in the same (descending) order as `batch`, and
    # batch-first padded targets. Confirm against the decoder.
    inputs, sizes, paddedTargets = self.decoder.buildCaptions(
        [e.tokens for e in batch]
    )
    outputDistributions = self.decoder(imageFeatures, inputs, sizes)

    # Flatten the padded targets into packed order (element [0] is the
    # packed data tensor) so they line up with the decoder output.
    # The decoder output is presumably in the same packed layout.
    packedTargets = pack_padded_sequence(paddedTargets, sizes, batch_first=True)[0]
    return F.cross_entropy(outputDistributions, packedTargets)
评论列表
文章目录