def step(self, mode):
    """Run one training or evaluation step on the next batch.

    Pulls one batch from the generator that matches ``mode``, feeds every
    element except the last to the matching compiled function, and treats
    the last element as the ground-truth answers.

    Args:
        mode: ``"train"`` to run ``self.train_fn`` on
            ``self.train_batch_gen``, or ``"test"`` to run ``self.test_fn``
            on ``self.test_batch_gen``.

    Returns:
        dict with keys:
            ``"prediction"``: ``expm1(ret[0])``. This presumes the model
                outputs predictions in ``log1p`` space (TODO confirm).
            ``"answers"``: the last element of the batch.
            ``"current_loss"``: ``ret[1]``, presumably the total loss.
            ``"loss_reg"``: ``ret[2]``, presumably the regularization term.
            ``"loss_mse"``: ``ret[1] - ret[2]``, i.e. the loss with the
                regularization term removed.
            ``"log"``: always the empty string.

    Raises:
        RuntimeError: if ``mode == "train"`` while ``self.mode == "test"``.
        ValueError: if ``mode`` is neither ``"train"`` nor ``"test"``.
        StopIteration: propagated if the selected batch generator is
            exhausted.
    """
    # Training must never run once the object has been switched to test mode.
    if mode == "train" and self.mode == "test":
        raise RuntimeError("Cannot train during test mode")
    if mode == "train":
        theano_fn = self.train_fn
        batch_gen = self.train_batch_gen
    elif mode == "test":
        theano_fn = self.test_fn
        batch_gen = self.test_batch_gen
    else:
        raise ValueError(
            f"Invalid mode {mode!r}; expected 'train' or 'test'")

    # Convention of the batch generators: the last element holds the
    # answers; all preceding elements are inputs to the compiled function.
    *inputs, ys = next(batch_gen)
    ret = theano_fn(*inputs)
    # NOTE(review): ret is assumed to hold at least 3 outputs
    # (prediction, total loss, regularization loss) -- confirm against
    # where train_fn / test_fn are compiled.
    return {
        # expm1 is the numerically accurate inverse of log1p; exp(x) - 1
        # loses precision when x is close to 0.
        "prediction": np.expm1(ret[0]),
        "answers": ys,
        "current_loss": ret[1],
        "loss_reg": ret[2],
        "loss_mse": ret[1] - ret[2],
        "log": "",
    }
评论列表
文章目录