def dot(x, y):
    """Multiply 2 tensors and return the product.

    When either argument has rank greater than 2, this reproduces the Theano
    behavior: the last axis of ``x`` is contracted with the second-to-last
    axis of ``y``, and the result shape is
    ``x.shape[:-1] + y.shape[:-2] + y.shape[-1:]``
    (e.g. ``(2, 3) . (4, 3, 5) -> (2, 4, 5)``).

    Otherwise a plain matrix product is computed, using a sparse-dense
    matmul when ``x`` is sparse.

    # Arguments
        x: Tensor or variable.
        y: Tensor or variable.

    # Returns
        A tensor, the dot product of ``x`` and ``y``.
    """

    def _full_shape(t):
        # Use the static size where it is known and the runtime size where it
        # is None. Passing None to tf.reshape fails, so unknown (dynamic)
        # dimensions must be resolved through tf.shape.
        return tuple(
            static if static is not None else dynamic
            for static, dynamic in zip(int_shape(t), tf.unstack(tf.shape(t)))
        )

    x_ndim = ndim(x)
    if x_ndim is not None and (x_ndim > 2 or ndim(y) > 2):
        x_shape = _full_shape(x)
        y_shape = _full_shape(y)
        # Move y's second-to-last axis (the contracted one) to the front so
        # y can be flattened into a (contracted_dim, everything_else) matrix.
        y_permute_dim = list(range(ndim(y)))
        y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim
        xt = tf.reshape(x, [-1, x_shape[-1]])
        yt = tf.reshape(tf.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
        return tf.reshape(
            tf.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:]
        )
    if is_sparse(x):
        out = tf.sparse_tensor_dense_matmul(x, y)
    else:
        out = tf.matmul(x, y)
    return out
评论列表
文章目录