loom_test.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def test_constant_network_with_tags(self):
  """Constants with different tags land in their own tag's output tensor.

  Two TypeShapes share dtype and shape but differ by tag ('alpha' vs
  'beta'). Each constant is created with one of those tags, and the outputs
  are requested in reverse order (c2 before c1). The assertions therefore
  check that values reach the output tensor of their TypeShape/tag rather
  than being routed by position in the outputs list.
  """
  shape1 = loom.TypeShape('int64', (3,), 'alpha')
  shape2 = loom.TypeShape('int64', (3,), 'beta')
  value1 = np.array([1, 2, 3], dtype='int64')
  value2 = np.array([4, 5, 6], dtype='int64')
  # The ops are never applied to the constants in this test; presumably
  # they exist so the Loom knows about both TypeShapes -- confirm against
  # loom.Loom.
  ops = {'add1': BinaryLoomOp(shape1, tf.add),
         'add2': BinaryLoomOp(shape2, tf.add)}
  the_loom = loom.Loom(named_ops=ops)
  output_tensor1 = the_loom.output_tensor(shape1)
  output_tensor2 = the_loom.output_tensor(shape2)
  with self.test_session() as sess:
    weaver = the_loom.make_weaver()
    c1 = weaver(value1, tag='alpha')
    c2 = weaver(value2, tag='beta')
    # Build the feed dict once and fetch both outputs in a single run,
    # rather than calling build_feed_dict([c2, c1]) a second time on the
    # same weaver.
    feed_dict = weaver.build_feed_dict([c2, c1])
    result1, result2 = sess.run([output_tensor1, output_tensor2],
                                feed_dict=feed_dict)
  # assertAllEqual also compares shapes, unlike (a == b).all(), which
  # would pass on a broadcast-compatible shape mismatch and gives no
  # diagnostic output on failure.
  self.assertAllEqual(value1, result1[0])
  self.assertAllEqual(value2, result2[0])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号