def semantic_matrix(argv):
    """Return the pairwise cosine-similarity matrix between two sequence tensors.

    Presumably used as the function of a Keras ``Lambda`` layer. It receives a
    list ``[q, a]`` of two rank-3 tensors (rank is fixed by the
    ``permute_dimensions(..., [0, 2, 1])`` calls). The axis meaning is assumed
    to be (batch, steps, features) — TODO confirm against the caller.

    Args:
        argv: sequence of exactly two tensors ``[q, a]``. ``q`` is assumed to
            be (batch, len_q, dim) and ``a`` to be (batch, len_a, dim).

    Returns:
        A tensor of shape (batch, len_q, len_a). Entry ``[b, i, j]`` is the
        dot product of ``q[b, i]`` and ``a[b, j]`` divided by the product of
        their L2 norms plus ``SAFE_EPSILON``.

    Raises:
        AssertionError: if ``argv`` does not hold exactly two elements.
    """
    assert len(argv) == 2, "semantic_matrix expects exactly two tensors [q, a]"
    q, a = argv

    # Floor on the squared norm before taking sqrt. d/dx sqrt(x) is infinite
    # at x == 0, so an all-zero row (e.g. padding) would otherwise produce
    # 0 * inf = NaN gradients. The floor only alters rows whose squared norm is
    # below it. For an all-zero row the numerator is zero, so the forward
    # output is still 0.
    norm_floor = K.epsilon() ** 2

    # K.sum / K.square instead of the Theano-only tensor method ``.sum`` keep
    # this backend-agnostic.
    q_norm = K.sqrt(K.maximum(K.sum(K.square(q), axis=2, keepdims=True), norm_floor))
    a_norm = K.sqrt(K.maximum(K.sum(K.square(a), axis=2, keepdims=True), norm_floor))

    # Outer product of the norms: (batch, len_q, 1) x (batch, 1, len_a)
    # -> (batch, len_q, len_a), matching the shape of the dot-product matrix.
    denominator = K.batch_dot(q_norm, K.permute_dimensions(a_norm, [0, 2, 1]))
    dots = K.batch_dot(q, K.permute_dimensions(a, [0, 2, 1]))
    return dots / (denominator + SAFE_EPSILON)
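# NOTE(review): the original Chinese text of these notes was lost to a
# character-encoding error; the English below is reconstructed from the
# surviving fragments and may not match the author's exact wording.
# Notes on gathering values by idx in Theano:
# - The gather must also be indexed by the batch index (see
#   https://groups.google.com/forum/#!topic/theano-users/7gUdN6E00Dc).
# - For argmax, the axis to use is `2 - axis` (exact context lost).
# - In Theano, `a > 0` does not return a boolean array; it returns an
#   integer mask such as [1, 1, 0], which cannot be used directly as a
#   boolean index.
# - To set the negative entries to zero, use:
#   T.set_subtensor(ib[(ib < 0).nonzero()], 0)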
# (Web-page residue from the source blog post: "评论列表" = "Comment list",
# "文章目录" = "Table of contents".)