def compile(self):
    """Build the Theano training and prediction functions for this model.

    Two compiled callables are attached to the instance:

    * ``self.train_function(x, actions, y)`` evaluates
      ``self.squared_error(x, actions, y)``. It applies the updates that
      ``self.sgd`` produces for ``self.params`` and returns the cost.
    * ``self.predict_function(x, actions)`` evaluates
      ``self.output(x, actions)`` and applies no updates.

    The training inputs are declared with one more dimension than the
    prediction inputs (tensor4/matrix versus tensor3/vector).
    NOTE(review): this presumably means training consumes a batch while
    prediction consumes a single unbatched example -- confirm against
    ``self.output``.

    Both functions are compiled with ``on_unused_input='ignore'``, so an
    input that does not reach the graph is not an error. They also use
    ``allow_input_downcast=True``, so call-time values may be cast down to
    the declared variable dtype.

    Returns:
        ``self``, allowing the call to be chained.
    """
    # Compilation options shared by the training and prediction functions.
    compile_kwargs = dict(on_unused_input='ignore', allow_input_downcast=True)

    # Training graph: symbolic inputs -> cost, plus SGD parameter updates.
    train_x = T.tensor4('x_train')
    train_actions = T.matrix('actions_train')
    train_y = T.matrix('y_train')
    train_cost = self.squared_error(train_x, train_actions, train_y)
    train_updates = self.sgd(train_cost, self.params)
    self.train_function = theano.function(
        inputs=[train_x, train_actions, train_y],
        outputs=train_cost,
        updates=train_updates,
        **compile_kwargs
    )

    # Prediction graph: symbolic inputs -> model output, no updates.
    pred_x = T.tensor3('x_pred')
    pred_actions = T.vector('actions_pred')
    pred_output = self.output(pred_x, pred_actions)
    self.predict_function = theano.function(
        inputs=[pred_x, pred_actions],
        outputs=pred_output,
        **compile_kwargs
    )

    return self
tetris_theano.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录