attention_ops.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:hart 作者: akosiorek 项目源码 文件源码
def _stride_to_std(self, stride):
        shape = convert_shape(stride.get_shape())
        stride_flat = tf.reshape(stride, (-1, shape[-1]))
        y, x = stride_flat[..., 0], stride_flat[..., 1]
        features = [
            tf.ones_like(y),
            y, y ** 2, y ** 3, y ** 4,
            x, x ** 2, x ** 3, x ** 4,
               y * x, y * x ** 2, y ** 2 * x,
               y * x ** 3, y ** 2 * x ** 2, y ** 3 * x
        ]

        features = tf.concat(axis=1, values=[f[..., tf.newaxis] for f in features])
        sigma_flat = tf.matmul(features, self.weights)
        return tf.reshape(sigma_flat, shape)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号