base_network.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:ImageCaptioning 作者: rkuga 项目源码 文件源码
def step(self,perm,batch_index, mode, epoch): 
            if mode =='train':
                data, label=self.read_batch(perm,batch_index,self.train_data)
            else:
                data, label=self.read_batch(perm,batch_index,self.test_data)

            data = Variable(cuda.to_gpu(data))
            yl = self.network(data)

            label=Variable(cuda.to_gpu(label))

            L_network = F.softmax_cross_entropy(yl, label)
            A_network = F.accuracy(yl, label)

            if mode=='train':
                self.o_network.zero_grads()
                L_network.backward()
                self.o_network.update()


            return {"prediction": yl.data.get(),
                    "current_loss": L_network.data.get(),
                    "current_accuracy": A_network.data.get(),
            }
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号