def flatten2d(inputs, name=None):
    """Flatten a tensor to two dimensions: (batch_size, item_vector_size).

    The first axis is kept as the batch axis. All remaining axes are
    collapsed into one axis whose size is the product of their sizes.
    Inputs of rank 0 or 1 have no trailing axes, so that product is 1 and
    the result has a trailing size of 1.

    Args:
        inputs: Anything accepted by ``tf.convert_to_tensor``.
        name: Optional name for the final ``tf.reshape`` op.

    Returns:
        A rank-2 tensor of shape ``(-1, prod(inputs.shape[1:]))``.
    """
    x = tf.convert_to_tensor(inputs)
    # Prefer the static size of the trailing axes when it is fully known.
    # Passing a Python int keeps the output's last dimension statically
    # defined in graph / tf.function mode. A dynamically computed product
    # can leave it as None and break layers that need a known last dim.
    # num_elements() returns None when any trailing dim (or the rank) is
    # unknown; in that case fall back to computing the product at run time.
    dims = x.shape[1:].num_elements()
    if dims is None:
        dims = tf.reduce_prod(tf.shape(x)[1:])
    return tf.reshape(x, [-1, dims], name=name)
评论列表
文章目录