tetris_theano.py source file

python

Project: reinforcement_learning    Author: andreweskeclarke
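import theano
import theano.tensor as T

def compile(self):
    # Symbolic training inputs: a 4D tensor of states, plus actions and targets.
    x_train = T.tensor4('x_train')
    actions_train = T.matrix('actions_train')
    y_train = T.matrix('y_train')
    cost_function = self.squared_error(x_train, actions_train, y_train)
    # Each call returns the cost and applies the SGD updates to self.params.
    self.train_function = theano.function([x_train, actions_train, y_train],
                                          cost_function,
                                          updates=self.sgd(cost_function, self.params),
                                          on_unused_input='ignore',  # tolerate inputs absent from the graph
                                          allow_input_downcast=True)  # allow casting inputs to lower precision
    # Symbolic prediction inputs: one rank lower than the training inputs
    # (presumably a single state and action rather than a batch).
    x_pred = T.tensor3('x_pred')
    actions_pred = T.vector('actions_pred')
    output_function = self.output(x_pred, actions_pred)
    # No updates here: prediction only evaluates the network output.
    self.predict_function = theano.function([x_pred, actions_pred],
                                            output_function,
                                            on_unused_input='ignore',
                                            allow_input_downcast=True)
    # Return self so the call can be chained after construction.
    return self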