def test_embedding_int8(self):
    """Check that Embedding lookups accept int8 indices, including out-of-range ones.

    Builds an Embedding layer over a 2x2 float32 table. It then looks up the
    int8 indices 0, 1, 7 and -5, each as a single-element batch. Index 0 must
    yield the first table row and every other index the second row. The
    expected values for 7 and -5 suggest the layer folds indices back into
    range (presumably modulo the bucket count -- confirm against
    tdl.Embedding).
    """
    table = np.array([[1, 2], [3, 4]], dtype='float32')
    layer = tdl.Embedding(2, 2, initializer=table)
    indices = [0, 1, 7, -5]
    expected = [[[1, 2]], [[3, 4]], [[3, 4]], [[3, 4]]]
    with self.test_session() as sess:
        # Build one lookup op per index inside the session's graph, in the
        # same order as `indices`.
        lookups = []
        for index in indices:
            batch = tf.constant([index], dtype=tf.int8)
            lookups.append(layer(batch))
        # Initialize only after the layer has been called; the layer
        # presumably creates its variables lazily on first call.
        sess.run(tf.global_variables_initializer())
        actual = sess.run(lookups)
        self.assertAllEqual(expected, actual)
评论列表
文章目录