def test_fit_to_less_width(self):
    """Fit a tensor to a smaller width (i.e. trimming).

    Feed a 3D array of shape [batch, length, width] through a placeholder
    whose three dimensions are all left unspecified, apply `ops.fit()`
    with a target width smaller than the actual one, and check that the
    result equals the input with the trailing entries of its last axis
    removed.
    """
    batch, length, width = 2, 5, 4
    fit_width = 3

    # Build the graph first: the placeholder declares three dimensions
    # but no static size for any of them.
    placeholder = tf.placeholder(dtype=tf.float32, shape=[None, None, None])
    fitted = ops.fit(placeholder, fit_width)

    # Reference result: remove every index of the last axis that lies at
    # or beyond the target width.
    data = np.random.rand(batch, length, width)  # pylint: disable=I0011,E1101
    trailing = list(range(fit_width, width))
    expected = np.delete(data, trailing, axis=2)  # pylint: disable=I0011,E1101

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        actual = sess.run(fitted, feed_dict={placeholder: data})
    self.assertAllClose(expected, actual)
评论列表
文章目录