def hermitian_tuple_scorer(tuples_var, rank=None, n_emb=None, emb0=None, symmetry_coef=(1.0, 1.0),
                           learn_symmetry_coef=True):
    """
    The Hermitian Scorer can learn embeddings for non-symmetric relations.

    :param tuples_var: TensorFlow variable that encodes the tuples as inputs, e.g. an integer tensor of
        shape (n_tuples, 2) whose rows are pairs of embedding indices (as in the examples below)
    :param rank: size of the embeddings, including real and imaginary parts. The complex rank is half of it.
        Not needed if emb0 is given.
    :param n_emb: number of embeddings (not needed if initial embeddings are given)
    :param emb0: initial embeddings (optional). When omitted, a (n_emb, rank) matrix is drawn from a
        standard normal distribution.
    :param symmetry_coef: pair of weights (symmetric weight, anti-symmetric weight), cast to float32.
        As the examples below show, (1, 0) gives a symmetric score, (0, 1) an anti-symmetric (skewed)
        score, and any other pair a weighted combination of those two scores.
    :param learn_symmetry_coef: False if the symmetry coefficient is not learned [True by default]
    :return: a pair (scoring TensorFlow graph, parameters). The parameters have the form
        ([n_emb x rank] float32 embedding variable, symmetry coef variable)
    :raises ValueError: if emb0 is not given and either n_emb or rank is missing
    >>> embeddings = [[1., 1, 0, 3], [0, 1, 0, 1], [-1, 1, 1, 5]]
    >>> tuples_var = tf.Variable([[0, 1], [1, 0], [0, 2], [2, 0], [1, 2], [2, 1]])
    >>> (g, params) = hermitian_tuple_scorer(tuples_var, emb0=embeddings, symmetry_coef=(1.0, 0.0))
    >>> print(tf_eval(g)) # symmetric form
    [ 4. 4. 15. 15. 6. 6.]
    >>> (g, params) = hermitian_tuple_scorer(tuples_var, emb0=embeddings, symmetry_coef=(0.0, 1.0))
    >>> print(tf_eval(g)) # skewed (anti-symmetric) form
    [-2. 2. 3. -3. 4. -4.]
    >>> (g, params) = hermitian_tuple_scorer(tuples_var, emb0=embeddings, symmetry_coef=(1.0, 1.0))
    >>> print(tf_eval(g)) # combination of the previous two forms
    [ 2. 6. 18. 12. 10. 2.]
    >>> (g, params) = hermitian_tuple_scorer(tuples_var, emb0=embeddings, symmetry_coef=(0.9, 0.1))
    >>> print(tf_eval(g)) # close to symmetric
    [ 3.39999986 3.79999995 13.80000019 13.19999981 5.79999971
    4.99999952]
    """
    if emb0 is None:
        # Fail early with a clear message instead of inside np.random.normal(size=(None, None)).
        if n_emb is None or rank is None:
            raise ValueError('n_emb and rank are required when emb0 is not given')
        emb0 = np.random.normal(size=(n_emb, rank))
    # The name must be passed by keyword: the second positional argument of tf.Variable is
    # `trainable`, so the previous positional 'embeddings' never named the variable.
    embeddings = tf.Variable(tf.cast(emb0, 'float32'), name='embeddings')
    # Cast so int or float64 coefficients get the same float32 dtype as the embeddings.
    symmetry_coef = tf.Variable(tf.cast(symmetry_coef, 'float32'), name='symmetry_coef',
                                trainable=learn_symmetry_coef)
    params = (embeddings, symmetry_coef)
    return sparse_hermitian_scoring(params, tuples_var), params
评论列表
文章目录