loom_test.py source code


Project: fold    Author: tensorflow
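```python
def test_two_ops_network_tagged_named_tensorx(self):
    # Every value flowing through this loom is an int64 vector of length 3 tagged 'x'.
    shape = loom.TypeShape('int64', (3,), tag='x')
    # Two binary ops over that TypeShape. BinaryLoomOp is a helper LoomOp
    # that is not shown in this excerpt.
    ops = {'add': BinaryLoomOp(shape, tf.add),
           'mul': BinaryLoomOp(shape, tf.multiply)}
    # Named constant tensors, each paired with the tag of its TypeShape.
    named_tensors = {
        'c1': (tf.constant(np.array([1, 2, 3], dtype='int64')), 'x'),
        'c2': (tf.constant(np.array([2, 4, 6], dtype='int64')), 'x'),
        'c3': (tf.constant(np.array([3, 6, 9], dtype='int64')), 'x')
    }
    the_loom = loom.Loom(named_ops=ops, named_tensors=named_tensors)
    output_tensor = the_loom.output_tensor(shape)
    with self.test_session():
      weaver = the_loom.make_weaver()
      # Build c1 * (c2 + c3): named tensors are weaver attributes,
      # named ops are weaver methods.
      sum_2_3 = weaver.add(weaver.c2, weaver.c3)
      sum_12_13 = weaver.mul(weaver.c1, sum_2_3)
      # Feed the woven graph, marking sum_12_13 as the output to collect.
      result = output_tensor.eval(
          feed_dict=weaver.build_feed_dict([sum_12_13]))
    # The output has a leading batch dimension, so one output has shape (1, 3):
    # [1, 2, 3] * ([2, 4, 6] + [3, 6, 9]) = [5, 20, 45]
    self.assertTrue((result == np.array([[5, 20, 45]], dtype='int64')).all())
```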
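TensorFlow Fold's Loom is built around dynamic batching: the weaver records an expression graph such as `c1 * (c2 + c3)`, and every op at the same depth runs once on a stacked batch of its inputs. The toy below re-creates that idea in plain NumPy so the test's arithmetic can be followed without TensorFlow. It is a minimal sketch under that assumption, not Loom's implementation: `MiniWeaver`, `op` and `run` are hypothetical names, and the real library's scheduling and API differ.

```python
import numpy as np


class MiniWeaver:
    """Hypothetical toy weaver (not Loom's API): records an expression DAG,
    then evaluates it depth by depth, calling each op once per depth on a
    stacked batch of all its pending inputs."""

    def __init__(self, named_tensors, named_ops):
        self.ops = named_ops
        self.nodes = []    # node_id -> (op_name, input_ids, depth)
        self.values = {}   # node_id -> computed np.ndarray
        # Constants get depth 0 and are exposed as attributes, e.g. w.c1.
        for name, value in named_tensors.items():
            setattr(self, name, self._add_node(None, (), value))

    def _add_node(self, op_name, inputs, value=None):
        # A node sits one level above its deepest input.
        depth = 1 + max((self.nodes[i][2] for i in inputs), default=-1)
        self.nodes.append((op_name, tuple(inputs), depth))
        node_id = len(self.nodes) - 1
        if value is not None:
            self.values[node_id] = np.asarray(value)
        return node_id

    def op(self, op_name, *inputs):
        """Record an op application; returns a node id, computes nothing yet."""
        return self._add_node(op_name, inputs)

    def run(self, outputs):
        """Evaluate the DAG and return the requested outputs stacked as a batch."""
        max_depth = max(depth for _, _, depth in self.nodes)
        for depth in range(1, max_depth + 1):
            # Group the nodes at this depth by op name.
            groups = {}
            for node_id, (op_name, inputs, d) in enumerate(self.nodes):
                if d == depth:
                    groups.setdefault(op_name, []).append((node_id, inputs))
            for op_name, members in groups.items():
                # Stack argument k of every member into one batch array,
                # then call the op a single time for the whole group.
                args = [np.stack([self.values[ins[k]] for _, ins in members])
                        for k in range(len(members[0][1]))]
                batch_result = self.ops[op_name](*args)
                for row, (node_id, _) in zip(batch_result, members):
                    self.values[node_id] = row
        return np.stack([self.values[o] for o in outputs])


# Same expression as the test: c1 * (c2 + c3).
w = MiniWeaver({'c1': [1, 2, 3], 'c2': [2, 4, 6], 'c3': [3, 6, 9]},
               {'add': np.add, 'mul': np.multiply})
total = w.op('mul', w.c1, w.op('add', w.c2, w.c3))
result = w.run([total])  # [[5, 20, 45]], leading batch dimension as in the test
```

Grouping by depth is what makes the batching safe: every input of a node has a strictly smaller depth, so it has already been computed when that node's group runs. Two independent `add` calls at the same depth would share a single `np.add` call.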