regression.py source code

python

Project: denet  Author: lachlants
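# Module-level imports used by this method.
import numpy
import theano

def get_target(self, model, samples, metas):
        # Collect flat (raveled) indices into an array of shape self.output_shape,
        # selecting the entries that belong to each sample's ground-truth class.
        yt_index = []
        if len(self.output_shape) == 2:
            # Two-index output (batch, class): one index per sample.
            for b in range(len(metas)):
                cls = metas[b]["image_class"]
                yt_index.append(numpy.ravel_multi_index((b, cls), self.output_shape))

        elif len(self.valid) > 0:
            # Three-index output (batch, class, v): one index per entry of self.valid.
            for b in range(len(metas)):
                cls = metas[b]["image_class"]
                for v in range(len(self.valid)):
                    yt_index.append(numpy.ravel_multi_index((b, cls, v), self.output_shape))
        else:
            # Four-index output (batch, class, y, x): one index per spatial position.
            for b in range(len(metas)):
                cls = metas[b]["image_class"]
                for y in range(self.output_shape[2]):
                    for x in range(self.output_shape[3]):
                        yt_index.append(numpy.ravel_multi_index((b, cls, y, x), self.output_shape))

        # Only indices are returned; the second element is an empty floatX array.
        return numpy.array(yt_index, dtype=numpy.int64), numpy.array([], dtype=theano.config.floatX)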

    #return negative log-likelihood training cost (scalar)
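
The index bookkeeping in get_target can be hard to follow from the nested loops alone. Below is a minimal sketch of the same flat-index computation for the two-index and four-index cases. The function name flat_target_indices and its arguments are hypothetical and not part of denet. It uses NumPy only and is an illustration, not the author's implementation.

```python
import numpy as np

def flat_target_indices(classes, output_shape):
    """Hypothetical helper: flat indices of each sample's class entries.

    classes      -- one integer class label per sample in the batch
    output_shape -- (batch, n_classes) or (batch, n_classes, height, width)
    """
    classes = np.asarray(classes, dtype=np.int64)
    batch = np.arange(len(classes), dtype=np.int64)

    if len(output_shape) == 2:
        # One flat index per sample: position (b, class_b).
        return np.ravel_multi_index((batch, classes), output_shape)

    # Four-index case: every spatial position (y, x) of the class channel.
    height, width = output_shape[2], output_shape[3]
    b, c, y, x = np.broadcast_arrays(
        batch[:, None, None],
        classes[:, None, None],
        np.arange(height)[None, :, None],
        np.arange(width)[None, None, :],
    )
    # Flattening in C order reproduces the loop order b, then y, then x.
    return np.ravel_multi_index((b, c, y, x), output_shape).ravel()


# Two samples with classes 2 and 0 in a (2, 3) output.
print(flat_target_indices([2, 0], (2, 3)))
# Two samples with classes 1 and 0 in a (2, 3, 2, 2) output.
print(flat_target_indices([1, 0], (2, 3, 2, 2)))
```

For a C-ordered array of shape (B, C), the flat index of (b, c) is b * C + c, so the first call yields 2 and 3. In the four-index case each sample contributes height * width consecutive indices for its class channel.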