def grams(X):
    """Return a normalized Gram matrix for every sample in a 4D tensor.

    The input is brought to a channels-first layout (permuted with
    ``(0, 3, 1, 2)`` when ``K.image_dim_ordering()`` reports ``'tf'``),
    flattened to ``(-1, c, h * w)`` and multiplied, per sample, by its own
    transpose. The product is divided by ``c * h * w``.

    Args:
        X: A 4D backend tensor whose shape ``get_shape`` can unpack into
            four values (batch size, then ``c, h, w`` after the optional
            permutation).

    Returns:
        A backend tensor holding one ``c x c`` matrix per sample.
    """
    # Work in channels-first order regardless of the configured ordering.
    if K.image_dim_ordering() == 'tf':
        X = K.permute_dimensions(X, (0, 3, 1, 2))

    # The batch size itself is not needed; reshape infers it via -1.
    _, c, h, w = get_shape(X)

    features = K.reshape(X, (-1, c, h * w))
    features_t = K.permute_dimensions(features, (0, 2, 1))

    if K._BACKEND == 'theano':
        gram = T.batched_dot(features, features_t)
    else:
        # tf.batch_matmul was removed in TensorFlow 1.0, where tf.matmul
        # handles batched (rank >= 3) operands. Prefer the old name when it
        # exists so older installations behave exactly as before.
        batch_matmul = getattr(tf, 'batch_matmul', None)
        if batch_matmul is None:
            batch_matmul = tf.matmul
        gram = batch_matmul(features, features_t)

    # Normalize by the number of elements in one sample's feature volume.
    return gram / (c * h * w)
# Comment list (page navigation text, not code)
# Table of contents (page navigation text, not code)