def __init__(self, conf):
    """Build and compile the Theano graph for a two-input autoencoder.

    Resolves ``conf.act`` to an activation callable and builds the network via
    ``self.__create_toplogy__``. It then compiles three Theano functions:

    * ``self.train_fn(x1, x2, target)`` -- returns the mean squared error
      between the network output and ``target``. It also applies one Nesterov
      momentum update (``conf.lr``, ``conf.momentum``) to all trainable
      parameters.
    * ``self.reconstruction_fn(x1, x2)`` -- the network output computed with
      ``deterministic=True``.
    * ``self.encoding_fn(x1, x2)`` -- the outputs of the two encoder layers
      returned by ``__create_toplogy__``, also with ``deterministic=True``.

    It also caches double-precision BLAS ``nrm2`` and ``scal`` routines as
    ``self.blas_nrm2`` and ``self.blas_scal``. Finally, if
    ``conf.load_model`` is not None, it calls ``self.load_model()``.

    Args:
        conf: configuration object. This method reads ``act``, ``lr``,
            ``momentum`` and ``load_model``. ``conf.act`` may be one of
            ``"linear"``, ``"sigmoid"``, ``"relu"``, ``"tanh"`` or an
            already-resolved activation callable. NOTE: ``conf.act`` is
            overwritten in place with the resolved callable, so the same
            ``conf`` can safely be reused to build another instance.

    Raises:
        ValueError: if ``conf.act`` is neither a known name nor callable.
    """
    self.conf = conf

    # Map config names to activation functions ("relu" -> rectify).
    activations = {
        "linear": linear,
        "sigmoid": sigmoid,
        "relu": rectify,
        "tanh": tanh,
    }
    act = self.conf.act
    # conf.act is rewritten below; accepting an already-resolved callable keeps
    # a second construction from the same conf from failing.
    if not callable(act):
        try:
            act = activations[act]
        except (KeyError, TypeError):
            # Unknown (or unhashable) name: same error as before.
            raise ValueError("Unknown activation function", act)
    self.conf.act = act

    # Symbolic 2-D inputs: the two views and the reconstruction target.
    input_var_first = T.matrix('inputs1')
    input_var_second = T.matrix('inputs2')
    target_var = T.matrix('targets')

    # Create network: full autoencoder plus the two encoder layers.
    self.autoencoder, encoder_first, encoder_second = self.__create_toplogy__(
        input_var_first, input_var_second)
    self.out = get_output(self.autoencoder)

    # Mean squared reconstruction error, optimised with Nesterov momentum.
    loss = squared_error(self.out, target_var).mean()
    params = get_all_params(self.autoencoder, trainable=True)
    updates = nesterov_momentum(loss, params,
                                learning_rate=self.conf.lr,
                                momentum=self.conf.momentum)

    # Training function: returns the loss and applies the parameter updates.
    self.train_fn = theano.function(
        [input_var_first, input_var_second, target_var], loss, updates=updates)

    # Function to reconstruct (deterministic pass, no parameter updates).
    test_reconstruction = get_output(self.autoencoder, deterministic=True)
    self.reconstruction_fn = theano.function(
        [input_var_first, input_var_second], test_reconstruction)

    # Encoding function: outputs of both encoder layers.
    test_encode = get_output([encoder_first, encoder_second], deterministic=True)
    self.encoding_fn = theano.function(
        [input_var_first, input_var_second], test_encode)

    # Utils: BLAS routines picked for float64 arrays (dtype=float).
    def blas(name, ndarray):
        return scipy.linalg.get_blas_funcs((name,), (ndarray,))[0]

    self.blas_nrm2 = blas('nrm2', np.array([], dtype=float))
    self.blas_scal = blas('scal', np.array([], dtype=float))

    # Load weights if necessary.
    if self.conf.load_model is not None:
        self.load_model()
# Scrape residue, not code: "评论列表" ("Comment list"), "文章目录" ("Table of contents").