def __init__(self):
    """Build the symbolic variables, the network, and the compiled functions.

    The base-class initializer is given the metric names
    ``'Loss'``, ``'L2'`` and ``'Accuracy'``. After construction the
    instance exposes:

    * ``self.network``  -- whatever ``fr3dnet.define_network`` returns.
    * ``self.train_fn`` -- first value returned by ``fr3dnet.define_updates``.
    * ``self.val_fn``   -- second value returned by ``fr3dnet.define_updates``.
    * ``self.l_r``      -- third value returned by ``fr3dnet.define_updates``
      (presumably the learning rate; confirm against ``fr3dnet``).
    """
    super(Fr3dNetTrainer, self).__init__(['Loss', 'L2', 'Accuracy'])

    # Theano ships no ready-made 5-D tensor constructor here, so a type is
    # built by hand: floatX dtype with five broadcast flags, all False.
    no_broadcast = (False,) * 5
    five_d_type = T.TensorType(theano.config.floatX, no_broadcast)
    inputs_sym = five_d_type('inputs')
    # NOTE(review): ivector suggests integer targets (e.g. class indices)
    # -- confirm against fr3dnet.define_updates.
    targets_sym = T.ivector('targets')

    logging.info("Defining network")
    self.network = fr3dnet.define_network(inputs_sym)

    # define_updates hands back exactly three objects; they are stored on the
    # instance in the same order they are returned.
    self.train_fn, self.val_fn, self.l_r = fr3dnet.define_updates(
        self.network, inputs_sym, targets_sym
    )
评论列表
文章目录