def _stride_to_std(self, stride):
    """Map stride values through a fixed bivariate quartic polynomial.

    The last axis of ``stride`` is read as coordinate pairs: channel 0 is
    ``y`` and channel 1 is ``x``, so that axis must have at least two
    entries. For every row, all 15 monomials ``y**i * x**j`` with
    ``i + j <= 4`` are evaluated. The resulting feature matrix is then
    multiplied by ``self.weights``, and the product is reshaped back to
    ``convert_shape(stride.get_shape())``.

    NOTE(review): for the final reshape to succeed, ``self.weights`` is
    presumably a ``(15, stride.shape[-1])`` matrix -- TODO confirm. How
    ``convert_shape`` treats unknown (None) dimensions is not visible here.

    Args:
        stride: Tensor whose last axis holds (at least) the y and x channels.

    Returns:
        Tensor shaped like ``convert_shape(stride.get_shape())``.
    """
    target_shape = convert_shape(stride.get_shape())
    # Collapse every leading axis so each row is one sample.
    rows = tf.reshape(stride, (-1, target_shape[-1]))
    y = rows[..., 0]
    x = rows[..., 1]

    # Monomials grouped by kind. The column order must line up with the rows
    # of self.weights, so it is kept exactly as the weights expect.
    constant = [tf.ones_like(y)]
    pure_y = [y, y ** 2, y ** 3, y ** 4]
    pure_x = [x, x ** 2, x ** 3, x ** 4]
    cross = [
        y * x, y * x ** 2, y ** 2 * x,
        y * x ** 3, y ** 2 * x ** 2, y ** 3 * x,
    ]
    # Stacking the 1-D columns along axis 1 yields the (rows, 15) design matrix.
    design = tf.stack(constant + pure_y + pure_x + cross, axis=1)

    return tf.reshape(tf.matmul(design, self.weights), target_shape)
评论列表
文章目录