def __init__(self, n_in, n_out, batchnorm=False, preactivation=True, gate_style='add_split', kernel_size=7):
    """Store the layer configuration and build its four SMASHseq sub-ops.

    The sub-ops are held in ``self.op`` as an ``nn.ModuleList``, in index
    order 0..3.

    Args:
        n_in: Passed as ``n_in`` to the even-indexed sub-ops (0 and 2).
        n_out: Passed as ``n_out`` to every sub-op, and as ``n_in`` to the
            odd-indexed sub-ops (1 and 3).
        batchnorm: Stored on the layer and forwarded to every sub-op.
        preactivation: Stored on the layer and forwarded to every sub-op.
        gate_style: Stored on the layer; not used during construction.
        kernel_size: Forwarded to every sub-op.
    """
    super(SMASHLayer, self).__init__()
    self.n_out = n_out
    self.n_in = n_in
    self.batchnorm = batchnorm
    self.preactivation = preactivation
    self.gate_style = gate_style
    # NOTE: may want to make n_in and n_out more dynamic here.
    sub_ops = []
    for slot in range(4):
        # Odd slots read an n_out-wide input -- presumably the output of the
        # preceding even slot (TODO: confirm against forward()); even slots
        # read the n_in-wide layer input.
        slot_in = n_out if slot % 2 else n_in
        sub_ops.append(
            SMASHseq(
                n_in=slot_in,
                n_out=n_out,
                dilation=1,
                batchnorm=self.batchnorm,
                preactivation=self.preactivation,
                kernel_size=kernel_size,
            )
        )
    self.op = nn.ModuleList(sub_ops)
# Op represents the op definition, gate indicates whether to use tanh-sigmoid
# multiplicative gates, dilation holds the individual dilation factors, and NL
# is the particular activation to use at each ungated conv.
# Groups is currently inactive; we'd need to make sure we slice differently
# if using a variable number of groups.
评论列表
文章目录