def PhaseShift_withConv(x, r, filters, kernel_size=(3, 3), stride=(1, 1)):
    """Convolve ``x`` to ``filters * r**2`` channels, then apply ``PhaseShift(x, r)``.

    The convolution emits ``filters * r**2`` channels so that ``PhaseShift``
    can presumably fold the extra ``r**2`` factor into an ``r``-times larger
    spatial grid with ``filters`` channels (sub-pixel style upsampling --
    confirm against the ``PhaseShift`` definition).

    Args:
        x: Input tensor. The original shape comment implies a
            ``(batch, x_h, x_w, channels)`` layout -- TODO confirm with callers.
        r: Factor passed to ``PhaseShift``; also sets the conv width to
            ``filters * r**2`` output channels.
        filters: Number of channels expected in the final output.
        kernel_size: Convolution kernel size, forwarded to ``tcl.conv2d``.
        stride: Convolution stride, forwarded to ``tcl.conv2d``.

    Returns:
        The result of ``PhaseShift``. Per the original comment, with the
        default ``stride=(1, 1)`` its shape is ``(batch, r*x_h, r*x_w, filters)``.
        NOTE(review): assuming TensorFlow ``'SAME'`` padding semantics, a
        stride ``(s_h, s_w)`` first shrinks the spatial dims to
        ``ceil(x_h / s_h)`` x ``ceil(x_w / s_w)``, so the output is then
        ``(batch, r*ceil(x_h/s_h), r*ceil(x_w/s_w), filters)``, not the
        stride-1 shape above.
    """
    # NOTE(review): ``tcl`` is presumably ``tensorflow.contrib.layers``; if so,
    # its ``conv2d`` applies a default activation (ReLU) unless
    # ``activation_fn`` is passed -- confirm that is intended before the shift.
    x = tcl.conv2d(
        x,
        num_outputs=filters * r**2,  # r**2 extra channels per output channel for PhaseShift
        kernel_size=kernel_size,
        stride=stride,
        padding='SAME',
    )
    x = PhaseShift(x, r)
    return x
评论列表
文章目录