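import tensorflow as tf


def _SequentialBatchFFTGrad(op, grad):
    # Gradient of the batched forward FFT (along axis 1). Since the DFT
    # matrix F satisfies F^H = N * F^-1, the input gradient is N * IFFT(grad).
    # sequential_batch_ifft is assumed to be defined elsewhere in this module.
    if grad.dtype == tf.complex64:
        size = tf.cast(tf.shape(grad)[1], tf.float32)
        return (sequential_batch_ifft(grad, op.get_attr("compute_size"))
                * tf.complex(size, 0.))
    else:
        # complex128: tf.complex needs float64 real and imaginary parts.
        size = tf.cast(tf.shape(grad)[1], tf.float64)
        return (sequential_batch_ifft(grad, op.get_attr("compute_size"))
                * tf.complex(size, tf.zeros([], tf.float64)))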
Source file: sequential_batch_fft_ops.py (Python)
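The gradient above relies on the identity F^H g = N * ifft(g) for the unnormalized DFT matrix F. Below is a minimal pure-Python sketch that checks this identity without TensorFlow: it builds F explicitly, computes F^H g directly, and compares the result with N * ifft(g), applied row by row so that N is the row length, as with `tf.shape(grad)[1]`. The helper names (`dft_matrix`, `idft`, `vjp_via_adjoint`, `vjp_via_ifft`, `batch_vjp_via_ifft`) are hypothetical. The sketch assumes TensorFlow's convention that the gradient of a complex linear map y = Fx is F^H applied to the upstream gradient.

```python
import cmath


def dft_matrix(n):
    """Unnormalized DFT matrix: F[k][m] = exp(-2*pi*i*k*m/n)."""
    return [[cmath.exp(-2j * cmath.pi * k * m / n) for m in range(n)]
            for k in range(n)]


def idft(values):
    """Inverse DFT including the 1/n factor (the usual ifft convention)."""
    n = len(values)
    return [sum(values[k] * cmath.exp(2j * cmath.pi * k * m / n)
                for k in range(n)) / n
            for m in range(n)]


def vjp_via_adjoint(grad):
    """Backward pass of y = F x computed directly as F^H @ grad."""
    n = len(grad)
    f = dft_matrix(n)
    # (F^H g)[m] = sum_k conj(F[k][m]) * g[k]
    return [sum(f[k][m].conjugate() * grad[k] for k in range(n))
            for m in range(n)]


def vjp_via_ifft(grad):
    """Shortcut used by the gradient function for one row: N * ifft(grad)."""
    n = len(grad)
    return [n * v for v in idft(grad)]


def batch_vjp_via_ifft(grad_rows):
    """Apply the shortcut row by row; N is the row length (axis 1)."""
    return [vjp_via_ifft(row) for row in grad_rows]


if __name__ == "__main__":
    g = [1 + 2j, -0.5j, 3.0, 0.25 - 1j]
    a = vjp_via_adjoint(g)
    b = vjp_via_ifft(g)
    print(max(abs(x - y) for x, y in zip(a, b)) < 1e-9)  # True
```

The direct F^H product costs O(N^2) per row. The gradient function instead reuses the inverse transform, which an FFT computes in O(N log N). The naive `idft` here is O(N^2) only to keep the sketch dependency-free.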