noTraceBaseline.py source file

python

Project: TikZ  Author: ellisk42
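import numpy as np
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

# Method of the model class. `variable` is not defined in this snippet;
# it is presumably a helper from elsewhere in the project.
def loss(self, examples):
    # IMPORTANT: sort the examples by token count, longest first.
    # pack_padded_sequence (used below) expects descending lengths by default.
    examples.sort(key=lambda e: len(e.tokens), reverse=True)

    # Draw every example and stack the results into one float32 batch.
    x = variable(np.array([e.sequence.draw() for e in examples], dtype=np.float32))

    x = x.unsqueeze(1)  # insert the channel dimension

    imageFeatures = self.encoder(x)

    # sizes holds the per-example sequence lengths; T holds the padded targets.
    inputs, sizes, T = self.decoder.buildCaptions([e.tokens for e in examples])

    outputDistributions = self.decoder(imageFeatures, inputs, sizes)

    # Drop the padding from the targets; [0] is the packed data tensor.
    T = pack_padded_sequence(T, sizes, batch_first=True)[0]

    return F.cross_entropy(outputDistributions, T)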
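
The last two statements of loss() rely on how pack_padded_sequence orders its output. The following is a minimal sketch of that step in isolation, not the project's code: the toy token ids, the vocabulary size, and the random logits are made up for illustration, and it assumes the decoder returns one row of logits per non-padding token in the same packed order.

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

# Hypothetical toy batch: three target sequences of token ids, padded with 0
# and already sorted by length in descending order (as loss() does).
targets = torch.tensor([[4, 2, 7],
                        [5, 3, 0],
                        [6, 0, 0]])
sizes = [3, 2, 1]

# Index 0 of the PackedSequence is its data tensor: only the real tokens,
# laid out time step by time step across the batch.
packed_targets = pack_padded_sequence(targets, sizes, batch_first=True)[0]

# Stand-in for the decoder output (assumption): one row of logits per
# packed token, over a made-up vocabulary of 8 symbols.
vocab_size = 8
torch.manual_seed(0)
logits = torch.randn(packed_targets.size(0), vocab_size)

# Padding never enters the loss because it was removed by packing.
loss = F.cross_entropy(logits, packed_targets)
```

Here packed_targets contains the six real tokens in the order 4, 5, 6, 2, 3, 7: all sequences at step 0, then the two that are still active at step 1, then the longest one at step 2. With the default enforce_sorted=True, pack_padded_sequence requires the lengths to be in descending order, which is why loss() sorts the examples first.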