def __init__(self, frame_size, n_frame_samples, n_rnn, dim,
             learn_h0, weight_norm):
    """Build the frame-level conditioning module.

    Expands each frame of raw samples into a `dim`-sized feature via a 1x1
    conv, runs a stack of GRU layers over frames, and learns a transposed-conv
    upsampling back to `frame_size` conditioning vectors per frame.

    Args:
        frame_size: kernel size of the learned upsampling — presumably the
            per-frame upsampling factor (TODO confirm against
            ``LearnedUpsampling1d``).
        n_frame_samples: samples per input frame (input channels of the
            1x1 conv).
        n_rnn: number of stacked GRU layers.
        dim: feature/hidden dimensionality used throughout the module.
        learn_h0: if True, the initial GRU hidden state is trainable;
            otherwise it is a fixed zero buffer that moves with the module.
        weight_norm: if True, wrap the conv layers in weight normalization.
    """
    super().__init__()
    self.frame_size = frame_size
    self.n_frame_samples = n_frame_samples
    self.dim = dim

    # Initial GRU hidden state: one (dim,) vector per layer.
    h0 = torch.zeros(n_rnn, dim)
    if learn_h0:
        self.h0 = torch.nn.Parameter(h0)
    else:
        # Register the tensor directly: it follows .to()/.cuda() and is
        # saved in state_dict without being trained. The old
        # torch.autograd.Variable wrapper is a deprecated no-op since
        # PyTorch 0.4 (Variable and Tensor were merged).
        self.register_buffer('h0', h0)

    # 1x1 conv = per-frame linear projection from raw samples to `dim`.
    self.input_expand = torch.nn.Conv1d(
        in_channels=n_frame_samples,
        out_channels=dim,
        kernel_size=1
    )
    # Use the in-place (trailing-underscore) init functions; the
    # underscore-free aliases are deprecated and removed in recent torch.
    init.kaiming_uniform_(self.input_expand.weight)
    init.constant_(self.input_expand.bias, 0)
    if weight_norm:
        self.input_expand = torch.nn.utils.weight_norm(self.input_expand)

    self.rnn = torch.nn.GRU(
        input_size=dim,
        hidden_size=dim,
        num_layers=n_rnn,
        batch_first=True
    )
    # GRU weight matrices store the three gates concatenated along dim 0;
    # the project helper `nn.concat_init` initializes each chunk with its
    # own scheme (lecun uniform for gates, orthogonal for the hidden-to-
    # hidden candidate block).
    for layer in range(n_rnn):
        nn.concat_init(
            getattr(self.rnn, 'weight_ih_l{}'.format(layer)),
            [nn.lecun_uniform, nn.lecun_uniform, nn.lecun_uniform]
        )
        init.constant_(getattr(self.rnn, 'bias_ih_l{}'.format(layer)), 0)
        nn.concat_init(
            getattr(self.rnn, 'weight_hh_l{}'.format(layer)),
            [nn.lecun_uniform, nn.lecun_uniform, init.orthogonal_]
        )
        init.constant_(getattr(self.rnn, 'bias_hh_l{}'.format(layer)), 0)

    # Learned transposed-conv upsampling from one hidden vector per frame
    # back to frame_size outputs per frame.
    self.upsampling = nn.LearnedUpsampling1d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=frame_size
    )
    # Glorot-style uniform range +-sqrt(6/dim) for the transposed conv
    # weight (6.0 makes the float division explicit).
    init.uniform_(
        self.upsampling.conv_t.weight,
        -np.sqrt(6.0 / dim), np.sqrt(6.0 / dim)
    )
    init.constant_(self.upsampling.bias, 0)
    if weight_norm:
        self.upsampling.conv_t = torch.nn.utils.weight_norm(
            self.upsampling.conv_t
        )
# NOTE(review): removed scraped-webpage artifacts that were pasted into the
# source ("评论列表" / "文章目录" — "comment list" / "article TOC"); they were
# bare identifier statements and would raise NameError at import time.