def multilinear_grad(emb: tf.Tensor, tuples: tf.Tensor, score=False) -> "tuple[tf.Tensor, tf.Tensor]":
    """Return the gradient and value of a bilinear (order-2) score.

    For every index pair ``(i, j)`` in ``tuples`` the score is the dot product
    ``<emb[i], emb[j]>``. Its gradient with respect to ``emb[i]`` is
    ``emb[j]`` and vice versa, so the gradient is obtained by swapping the two
    gathered embeddings of each pair.

    Args:
        emb: Embedding tensor whose last dimension is the embedding rank.
            Presumably (num_entities, rank) -- TODO confirm against callers.
        tuples: Integer index tensor into ``emb`` whose last dimension is the
            tuple order. Its static shape should be fully defined: the
            dimension values are passed to ``tf.reshape``, which cannot take
            ``None``.
        score: Must be True; only the (gradient, score) variant is implemented.

    Returns:
        ``(grad_score, preds)``: ``grad_score`` has shape
        ``tuple_shape[:-1] + [2, rank]`` and ``preds`` has shape
        ``tuple_shape[:-1]``.

    Raises:
        NotImplementedError: if the tuple order is not 2 or ``score`` is False.
    """
    tuple_shape = [d.value for d in tuples.get_shape()]
    # The order is the size of the last dimension; reading it from position -1
    # (rather than a hard-coded index 2) accepts index tensors of any rank.
    order = tuple_shape[-1]
    rank = emb.get_shape()[-1].value

    if order != 2 or not score:
        # TODO: not implemented yet:
        #  - order > 2: gradient of a product, i.e. prod(emb_sel) / emb_sel
        #  - score=False: returning the gradient alone
        raise NotImplementedError(
            'multilinear_grad only supports order-2 tuples with score=True '
            '(got order={}, score={})'.format(order, score))

    # emb_sel has shape tuple_shape + [rank]; the pair axis is second-to-last.
    emb_sel = tf.gather(emb, tuples)
    pair_axis = len(tuple_shape) - 1

    # Swap the two embeddings of each pair: d<e0, e1>/de0 = e1 and vice versa.
    # NOTE(review): this is the boolean-mask signature of tf.reverse
    # (TensorFlow < 1.0); on TF >= 1.0 use tf.reverse(emb_sel, [pair_axis]).
    reverse_mask = [False] * pair_axis + [True, False]
    grad_score = tf.reshape(tf.reverse(emb_sel, reverse_mask),
                            tuple_shape[:-1] + [2, rank])

    # Element-wise product of each pair, then sum over the embedding axis.
    prod = tf.reduce_prod(emb_sel, pair_axis)
    preds = tf.reshape(tf.reduce_sum(prod, pair_axis), tuple_shape[:-1])
    return grad_score, preds
评论列表
文章目录