def __init__(self, ch_in, ch_out, ksize, stride=1, pad=0, activation=F.relu):
    """Build a weight-normalized 2-D convolution link plus an activation.

    Args:
        ch_in: Input channel count; placed on the second axis of the
            convolution weight shape ``(ch_out, ch_in, kh, kw)``.
        ch_out: Output channel count; placed on the first axis of that shape.
        ksize: Kernel size. Anything indexable (e.g. a ``(kh, kw)`` tuple) is
            unpacked into height and width; otherwise the single value is
            used for both.
        stride: Passed through unchanged to ``WeightNormalization``.
        pad: Passed through unchanged to ``WeightNormalization``.
        activation: Callable stored on ``self.activation``. Presumably it is
            applied in the forward pass, which is not visible here — confirm.
    """
    # Scalar -> square kernel; indexable -> unpacked as (height, width).
    kernel_hw = ksize if hasattr(ksize, '__getitem__') else (ksize, ksize)
    kh, kw = kernel_hw

    weight_shape = (ch_out, ch_in, kh, kw)
    conv_link = WeightNormalization(
        F.convolution_2d, weight_shape, stride=stride, pad=pad)

    # Old-style Chainer child registration via Chain.__init__(**links).
    # NOTE(review): newer Chainer prefers `with self.init_scope():`;
    # confirm the targeted Chainer version before migrating.
    super(WNConv2D, self).__init__(wn_conv=conv_link)
    self.activation = activation
评论列表
文章目录