def __init__(self, frame_size, dim, q_levels, weight_norm):
    """Build the embedding table and the three 1-D convolution layers.

    Args:
        frame_size: Kernel size of the first (``input``) convolution.
        dim: Number of channels produced by ``input`` and carried
            through ``hidden``.
        q_levels: Size of the embedding table (used both as the number
            of entries and as the embedding width); also the channel
            count consumed by ``input`` and produced by ``output``.
        weight_norm: If truthy, each convolution is wrapped with
            ``torch.nn.utils.weight_norm`` after it has been initialised.
    """
    super().__init__()

    self.q_levels = q_levels

    # Square table: one q_levels-wide vector per index.
    self.embedding = torch.nn.Embedding(self.q_levels, self.q_levels)

    # NOTE(review): `init` is assumed to be `torch.nn.init` -- confirm
    # against the file's imports. The in-place `*_` initialisers are used
    # because the un-suffixed names are deprecated aliases.

    # The first convolution has no bias, so only its weight is initialised.
    self.input = torch.nn.Conv1d(
        in_channels=q_levels,
        out_channels=dim,
        kernel_size=frame_size,
        bias=False,
    )
    init.kaiming_uniform_(self.input.weight)
    # Weights are initialised *before* wrapping: weight_norm derives its
    # magnitude/direction parameters from the weight present at wrap time.
    # The classic torch.nn.utils.weight_norm is kept so that parameter
    # names in existing checkpoints stay unchanged.
    if weight_norm:
        self.input = torch.nn.utils.weight_norm(self.input)

    self.hidden = torch.nn.Conv1d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=1,
    )
    init.kaiming_uniform_(self.hidden.weight)
    init.constant_(self.hidden.bias, 0)
    if weight_norm:
        self.hidden = torch.nn.utils.weight_norm(self.hidden)

    self.output = torch.nn.Conv1d(
        in_channels=dim,
        out_channels=q_levels,
        kernel_size=1,
    )
    # `nn` is a module from outside this block (torch.nn itself has no
    # `lecun_uniform`) -- presumably a project-local helper; verify.
    nn.lecun_uniform(self.output.weight)
    init.constant_(self.output.bias, 0)
    if weight_norm:
        self.output = torch.nn.utils.weight_norm(self.output)
# Comment list
# Table of contents