def get_cost(self, data_iterator):
    """Return the accumulated cost over all batches of ``data_iterator``, normalized by its total.

    The model is temporarily switched to ``'predict'`` mode while costs are
    computed. Its previous mode (``self.mode``) is restored afterwards, even
    if evaluating a batch raises.

    Args:
        data_iterator: Batch source. It must provide ``begin(do_shuffle=...)``,
            ``get_batch()`` (returning the positional arguments for
            ``self.cost_func``), ``next()``, ``no_batch_left()`` and ``total()``.

    Returns:
        The sum of ``self.cost_func(*batch)`` over every batch, divided by
        ``data_iterator.total()``. NOTE(review): presumably ``total()`` is the
        example count, which makes this a per-example average — confirm.
    """
    ret = 0
    old_mode = self.mode
    self.set_mode('predict')
    try:
        data_iterator.begin(do_shuffle=False)
        # Do-while loop: a batch is fetched and costed before the first
        # no_batch_left() check. NOTE(review): this assumes the iterator
        # yields at least one batch after begin().
        while True:
            ret += self.cost_func(*data_iterator.get_batch())
            data_iterator.next()
            if data_iterator.no_batch_left():
                break
    finally:
        # Always restore the caller's mode, even if cost_func or the
        # iterator raised part-way through.
        self.set_mode(old_mode)
    return ret / data_iterator.total()
评论列表
文章目录