def __len__(self):
    """Return this object's length as stored in its ``D`` attribute.

    Python's built-in ``len()`` dispatches here. The value is returned
    unchanged, so ``D`` is expected to be a non-negative integer.
    NOTE(review): the type of ``D`` is not visible in this chunk; confirm.
    """
    size = self.D
    return size
# def save(self, savedir):
# state = {"version": 1,
# "seed": self.seed,
# "use_mask_as_input": self.use_mask_as_input,
# "batch_size": self.batch_size,
# "shared_batch_count": self.shared_batch_count.get_value(),
# "rng": pickle.dumps(self.rng),
# "shared_batch_mask": self._shared_mask_o_lt_d.get_value(),
# }
# np.savez(pjoin(savedir, 'mini_batch_scheduler_with_autoregressive_mask.npz'), **state)
# def load(self, loaddir):
# state = np.load(pjoin(loaddir, 'mini_batch_scheduler_with_autoregressive_mask.npz'))
# self.batch_size = state["batch_size"]
# self.shared_batch_count.set_value(state["shared_batch_count"])
# self.rng = pickle.loads(state["rng"])
# self._shared_mask_o_lt_d.set_value(state["shared_batch_mask"])
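# NOTE(review): the two lines below were scraped web-page navigation text, not code:
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")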