def step(self, mode):
    """Run one step on the next batch for the given mode.

    Pulls one batch from the generator selected by ``mode``, feeds it
    positionally to the function selected by ``mode``, and packages the
    outputs into a dict.

    Args:
        mode: Either ``"train"`` or ``"test"``. ``"train"`` uses
            ``self.train_fn`` / ``self.train_batch_gen``; ``"test"`` uses
            ``self.test_fn`` / ``self.test_batch_gen``.

    Returns:
        dict with keys:
            ``"prediction"``: element 0 of the function's result, wrapped
                in ``np.array``.
            ``"answers"``: the last element of the batch (presumably the
                targets -- confirm against the batch generator).
            ``"current_loss"``: element 1 of the function's result
                (presumably the loss -- confirm against how the function
                is compiled).
            ``"log"``: always an empty string.

    Raises:
        ValueError: if ``mode`` is ``"train"`` while ``self.mode`` is
            ``"test"``, or if ``mode`` is neither ``"train"`` nor
            ``"test"``. ``ValueError`` subclasses ``Exception``, so callers
            that catch ``Exception`` keep working.
        StopIteration: propagated from ``next()`` when the selected batch
            generator is exhausted.
    """
    # Training is disallowed when this object is configured for test mode.
    if mode == "train" and self.mode == "test":
        raise ValueError("Cannot train during test mode")
    if mode == "train":
        theano_fn = self.train_fn
        batch_gen = self.train_batch_gen
    elif mode == "test":
        theano_fn = self.test_fn
        batch_gen = self.test_batch_gen
    else:
        raise ValueError("Invalid mode")
    data = next(batch_gen)
    # The batch is unpacked positionally into the selected function.
    ret = theano_fn(*data)
    return {
        "prediction": np.array(ret[0]),
        "answers": data[-1],
        "current_loss": ret[1],
        "log": "",
    }
# Scraped web-page residue (not code): "评论列表" (comment list), "文章目录" (table of contents).