def get_error_dict(self, data_iterator):
    """Evaluate every registered error function over one full pass of the data.

    For each entry in ``self.error_func_dict`` the iterator is rewound
    (without shuffling). The error function is called on every batch with
    the unpacked result of ``data_iterator.get_batch()``, and the summed value
    is divided by ``data_iterator.total() * data_iterator.num_segments``.

    The model is switched to ``'predict'`` mode for the evaluation. The
    previous mode is restored afterwards, even if an error function raises.

    Args:
        data_iterator: batch iterator exposing ``begin``, ``get_batch``,
            ``next``, ``no_batch_left``, ``total`` and ``num_segments``.

    Returns:
        A dict mapping each key of ``self.error_func_dict`` to its normalised
        error, or ``None`` when no error functions are registered.
    """
    if not self.error_func_dict:
        # A binary-accuracy fallback used to live here; it was disabled
        # because it only applied to binary predictions.
        return None

    errors = {}
    old_mode = self.mode
    self.set_mode('predict')
    try:
        for key, error_func in self.error_func_dict.items():
            total_error = 0
            data_iterator.begin(do_shuffle=False)
            # get_batch() is called before no_batch_left() is checked, so the
            # iterator is assumed to hold at least one batch after begin().
            while True:
                total_error += error_func(*data_iterator.get_batch())
                data_iterator.next()
                if data_iterator.no_batch_left():
                    break
            num_samples = data_iterator.total() * data_iterator.num_segments
            # float() avoids floor division if the error functions return
            # ints and this runs under Python 2 division semantics.
            errors[key] = total_error / float(num_samples)
    finally:
        # Restore the caller's mode even if an error function raised.
        self.set_mode(old_mode)
    return errors
评论列表
文章目录