concat_op_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:complex_tf 作者: woodshop 项目源码 文件源码
def testGradientWithUnknownInputDim(self):
    """Check concat gradients for complex64 inputs with no static shape.

    `x` and `y` are placeholders created without a shape, and they are
    concatenated along axis 2. The per-input gradients are concatenated back
    along the same axis. The result must equal the upstream gradient
    `grad_inp` exactly.
    """

    def _random_complex(*shape):
      # complex64 array whose real and imaginary parts are drawn
      # independently from np.random.rand.
      return (np.random.rand(*shape) +
              1j * np.random.rand(*shape)).astype(np.complex64)

    with self.test_session(use_gpu=True):
      x = array_ops.placeholder(dtypes.complex64)
      y = array_ops.placeholder(dtypes.complex64)
      c = array_ops.concat([x, y], 2)

      # Matches the fed shapes below: last dim is 3 (x) + 6 (y) = 9.
      output_shape = [10, 2, 9]
      grad_inp = _random_complex(*output_shape)
      # Hand the array over directly with an explicit dtype rather than
      # copying it element-by-element into a Python list and relying on
      # dtype inference.
      grad_tensor = constant_op.constant(
          grad_inp, dtype=dtypes.complex64, shape=output_shape)

      grad = gradients_impl.gradients([c], [x, y], [grad_tensor])
      concated_grad = array_ops.concat(grad, 2)
      params = {
          x: _random_complex(10, 2, 3),
          y: _random_complex(10, 2, 6),
      }
      result = concated_grad.eval(feed_dict=params)

      # Exact equality: the gradient of concat only slices the upstream
      # gradient, so no arithmetic rounding is involved.
      self.assertAllEqual(result, grad_inp)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号