lstm_2layer.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:mimic3-benchmarks 作者: YerevaNN 项目源码 文件源码
def step(self, mode):
        if mode == "train" and self.mode == "test":
            raise Exception("Cannot train during test mode")

        if mode == "train":
            theano_fn = self.train_fn
            batch_gen = self.train_batch_gen
        elif mode == "test":    
            theano_fn = self.test_fn
            batch_gen = self.test_batch_gen
        else:
            raise Exception("Invalid mode")

        data = next(batch_gen)
        ret = theano_fn(*data)

        return {"prediction": np.array(ret[0]),
                "answers": data[-1],
                "current_loss": ret[1],
                "log": ""}
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号