def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels,
             weight_norm):
    """Build one frame-level RNN per entry of ``frame_sizes`` plus a sample-level MLP.

    Args:
        frame_sizes: Sequence of frame sizes. It must support indexing,
            because ``frame_sizes[0]`` is handed to the ``SampleLevelMLP``.
            Entry ``i`` yields a ``FrameLevelRNN`` whose sample count is
            ``frame_sizes[0] * ... * frame_sizes[i]`` (cumulative product).
        n_rnn: Forwarded unchanged to every ``FrameLevelRNN``
            (presumably the number of recurrent layers; confirm in
            ``FrameLevelRNN``).
        dim: Stored as ``self.dim``. Also forwarded to every
            ``FrameLevelRNN`` and to the ``SampleLevelMLP``.
        learn_h0: Forwarded unchanged to every ``FrameLevelRNN``.
        q_levels: Stored as ``self.q_levels`` and forwarded to the
            ``SampleLevelMLP`` (presumably the number of quantization
            levels; TODO confirm).
        weight_norm: Forwarded unchanged to every ``FrameLevelRNN`` and to
            the ``SampleLevelMLP``.
    """
    # Assumes the enclosing class derives from torch.nn.Module (it assigns a
    # ModuleList attribute). Initialize the base class before registering
    # submodules.
    super().__init__()

    self.dim = dim
    self.q_levels = q_levels

    # Pair each frame size with the running product of all frame sizes up to
    # and including it; that product is the per-tier sample count.
    tiers = []
    for tier_frame_size, tier_span in zip(frame_sizes,
                                          np.cumprod(frame_sizes)):
        tiers.append(
            FrameLevelRNN(
                tier_frame_size, int(tier_span), n_rnn, dim, learn_h0,
                weight_norm,
            )
        )
    self.frame_level_rnns = torch.nn.ModuleList(tiers)

    # The sample-level MLP is configured with the first frame size only.
    first_frame_size = frame_sizes[0]
    self.sample_level_mlp = SampleLevelMLP(
        first_frame_size, dim, q_levels, weight_norm
    )
评论列表
文章目录