model.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:online_action 作者: zhenyangli 项目源码 文件源码
def get_error_dict(self, data_iterator):
    """Evaluate every registered error function over one full pass of the data.

    The model is temporarily switched to ``'predict'`` mode. The previous mode
    is restored afterwards, even if an error function or the iterator raises.

    For each ``key, error_func`` in ``self.error_func_dict`` the iterator is
    rewound with ``data_iterator.begin(do_shuffle=False)``. Then
    ``error_func(*data_iterator.get_batch())`` is summed over all batches, and
    the sum is divided by ``data_iterator.total()``. Presumably ``total()`` is
    the number of samples, so the result is a mean error — TODO confirm
    against the iterator implementation.

    Args:
        data_iterator: object providing ``begin(do_shuffle=...)``,
            ``get_batch()``, ``next()``, ``no_batch_left()`` and ``total()``.

    Returns:
        dict mapping each key of ``self.error_func_dict`` to its normalized
        accumulated error, or ``None`` when no error functions are registered.
    """
    if not self.error_func_dict:
        # A fallback binary-prediction error rate used to live here; it was
        # disabled because it only applied to binary predictions.
        return None

    errors = {}
    old_mode = self.mode
    self.set_mode('predict')
    try:
        for key, error_func in self.error_func_dict.items():
            total_error = 0
            data_iterator.begin(do_shuffle=False)
            # do-while: the current batch is consumed before checking whether
            # any batches remain, matching the iterator's begin/next protocol.
            while True:
                total_error += error_func(*data_iterator.get_batch())
                data_iterator.next()
                if data_iterator.no_batch_left():
                    break
            # Bug fix: index by the loop variable, not the literal 'key'.
            errors[key] = total_error / data_iterator.total()
    finally:
        # Always restore the caller's mode, even if evaluation fails.
        self.set_mode(old_mode)
    return errors
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号