def outer_product(*inputs):
    """Computes the outer product of the given tensors.

    Each rank-1 input is reshaped so that its elements lie along its own
    axis: input ``i`` runs along axis ``i``, with size 1 on every other axis.
    The reshaped tensors are then multiplied together using broadcasting.
    For vectors of lengths ``n_1, ..., n_k`` the result has shape
    ``[n_1, ..., n_k]``. Element ``[i_1, ..., i_k]`` equals
    ``inputs[0][i_1] * ... * inputs[k-1][i_k]``.

    Inputs whose rank is not 1 are used unchanged. This lets callers pass
    tensors that are already shaped for broadcasting.

    Args:
      inputs: one or more 1-D `Tensor`s (vectors).

    Returns:
      A `Tensor` holding the outer product of `inputs`.

    Raises:
      ValueError: if no inputs are given.
    """
    if not inputs:
        raise ValueError("outer_product requires at least one input tensor.")
    order = len(inputs)
    factors = []
    for idx, input_ in enumerate(inputs):
        if len(input_.get_shape()) == 1:
            # Put this vector on axis `idx`; the size-1 axes broadcast against
            # the other factors. With two inputs this is [-1, 1] and [1, -1],
            # i.e. a column vector times a row vector.
            shape = [-1 if axis == idx else 1 for axis in range(order)]
            input_ = tf.reshape(input_, shape)
        factors.append(input_)
    output = factors[0]
    for factor in factors[1:]:
        output = tf.multiply(output, factor)
    return output
# Comment list
# Table of contents