def __init__(self, input_size, hidden_size, max_length, use_bias=True):
    """Allocate the cell's weights, optional bias and batch-norm layers.

    Args:
        input_size: Number of input features; the row count of
            ``weight_ih``.
        hidden_size: Number of hidden units. ``weight_ih``, ``weight_hh``
            and ``bias`` all span ``4 * hidden_size`` columns/entries, and
            ``bn_c`` normalizes ``hidden_size`` features.
        max_length: Passed unchanged to every ``SeparatedBatchNorm1d``
            (presumably the number of time steps with separate statistics;
            confirm against that class).
        use_bias: If truthy, a learnable ``bias`` vector is created;
            otherwise ``bias`` is registered as ``None``.

    The weight/bias tensors come from ``torch.FloatTensor(*shape)``, which
    leaves their contents uninitialized; ``reset_parameters()`` runs last and
    is presumably responsible for filling them (TODO confirm there).
    """
    super(BNLSTMCell, self).__init__()

    # Store the constructor configuration on the module.
    self.input_size, self.hidden_size = input_size, hidden_size
    self.max_length, self.use_bias = max_length, use_bias

    # Width of the stacked projection output. The factor of 4 presumably
    # corresponds to the four LSTM gates -- confirm against forward().
    gate_size = 4 * hidden_size

    # Registration order (weight_ih, weight_hh, bias) matters for the order
    # of parameters() and state_dict(), so it is kept as-is.
    self.weight_ih = nn.Parameter(torch.FloatTensor(input_size, gate_size))
    self.weight_hh = nn.Parameter(torch.FloatTensor(hidden_size, gate_size))
    if not use_bias:
        # Register an explicit None so that ``self.bias`` exists and is None.
        self.register_parameter('bias', None)
    else:
        self.bias = nn.Parameter(torch.FloatTensor(gate_size))

    def _make_bn(num_features):
        # All batch-norm layers share the same max_length.
        return SeparatedBatchNorm1d(
            num_features=num_features, max_length=max_length)

    # Batch-norm layers; by their names, presumably for the input-to-hidden
    # and hidden-to-hidden projections and the cell state (confirm in forward()).
    self.bn_ih = _make_bn(gate_size)
    self.bn_hh = _make_bn(gate_size)
    self.bn_c = _make_bn(hidden_size)

    self.reset_parameters()
评论列表
文章目录