test_ops.py source code

python

Project: LiTeFlow  Author: petrux
    def test_fit_to_less_width(self):
        """Fit a tensor to a smaller width (i.e. trimming).

        Given a 3D tensor of shape [batch, length, width], apply the
        `ops.fit()` operator to it with a smaller `width` as the
        target one and check that the trailing elements along the last
        axis of the tensor have been deleted.
        """
        batch = 2
        length = 5
        width = 4
        fit_width = 3
        delta = width - fit_width

        shape = [None, None, None]
        input_ = tf.placeholder(dtype=tf.float32, shape=shape)
        output = ops.fit(input_, fit_width)

        input_actual = np.random.rand(batch, length, width)  # pylint: disable=I0011,E1101
        delete_idx = [width - (i + 1) for i in range(delta)]
        output_expected = np.delete(input_actual, delete_idx, axis=2)  # pylint: disable=I0011,E1101
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            output_actual = sess.run(output, {input_: input_actual})
        self.assertAllClose(output_expected, output_actual)
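
The expected value in the test above is built by deleting the last `delta` indices along axis 2. The following is a minimal NumPy-only sketch of that step. It assumes that trimming means keeping the first `fit_width` entries of the last axis, which is what the test checks; the actual `ops.fit()` implementation is not shown here. Under that assumption, the `np.delete` construction is equivalent to a plain slice:

```python
import numpy as np

# Same sizes as in the test.
batch, length, width, fit_width = 2, 5, 4, 3
delta = width - fit_width

data = np.random.rand(batch, length, width)

# Indices of the trailing `delta` positions along the last axis, as in the test.
delete_idx = [width - (i + 1) for i in range(delta)]
trimmed_by_delete = np.delete(data, delete_idx, axis=2)

# Equivalent formulation: keep the first `fit_width` entries of the last axis.
trimmed_by_slice = data[:, :, :fit_width]

print(trimmed_by_delete.shape)                              # (2, 5, 3)
print(np.array_equal(trimmed_by_delete, trimmed_by_slice))  # True
```

The slice form is shorter, but the `np.delete` form makes explicit which positions are dropped, which may be why the test spells it out.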