layers_test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def test_embedding_int8(self):
    weights = np.array([[1, 2], [3, 4]], dtype='float32')
    embedding = tdl.Embedding(2, 2, initializer=weights)
    with self.test_session() as sess:
      embeddings = [embedding(tf.constant([x], dtype=tf.int8))
                    for x in [0, 1, 7, -5]]
      sess.run(tf.global_variables_initializer())
      self.assertAllEqual([[[1, 2]], [[3, 4]], [[3, 4]], [[3, 4]]],
                          sess.run(embeddings))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号