def __init__(self, n_frames=2, n_steps=10):
    """Sets up layers"""
    super(SqueezeNetLSTM, self).__init__()
    self.n_frames = n_frames
    self.n_steps = n_steps
    # Convolutional stem applied to the stacked input frames
    # (3 * 2 * n_frames input channels).
    self.pre_metadata_features = nn.Sequential(
        nn.Conv2d(3 * 2 * self.n_frames, 16, kernel_size=3, stride=2),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        Fire(16, 6, 12, 12)
    )
    # Deeper Fire stack. Its 36 input channels exceed the stem's 24 output
    # channels, leaving room for the metadata channels added between the
    # two stages.
    self.post_metadata_features = nn.Sequential(
        Fire(36, 8, 16, 16),
        nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        Fire(32, 12, 24, 24),
        Fire(48, 12, 24, 24),
        nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        Fire(48, 16, 32, 32),
        Fire(64, 16, 32, 32),
        Fire(64, 24, 48, 48),
        Fire(96, 24, 48, 48),
    )
    # Output head: dropout, a 1x1 convolution producing n_steps * 2
    # channels, then average pooling before the LSTMs.
    final_conv = nn.Conv2d(96, self.n_steps * 2, kernel_size=1)
    self.pre_lstm_output = nn.Sequential(
        nn.Dropout(p=0.5),
        final_conv,
        nn.AvgPool2d(kernel_size=3, stride=2),
    )
    # Recurrent layers: a two-layer LSTM (input 16, hidden 32) and a
    # single-layer LSTM (input 32, hidden 4), both batch-first.
    self.lstms = nn.ModuleList([
        nn.LSTM(16, 32, 2, batch_first=True),
        nn.LSTM(32, 4, 1, batch_first=True)
    ])
    # Weight initialization: a small Gaussian for the final 1x1 conv,
    # Kaiming-uniform for every other convolution, and zeroed biases.
    for mod in self.modules():
        if isinstance(mod, nn.Conv2d):
            if mod is final_conv:
                init.normal_(mod.weight, mean=0.0, std=0.01)
            else:
                init.kaiming_uniform_(mod.weight)
            if mod.bias is not None:
                mod.bias.data.zero_()
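
Since the listing only shows __init__, here is a minimal, hypothetical sketch of how the convolutional front end could be exercised on a dummy batch. It assumes the surrounding class is named SqueezeNetLSTM (as in the super() call), that Fire is torchvision's Fire block (from torchvision.models.squeezenet import Fire), and an arbitrary 128x128 input resolution; none of these are confirmed by the listing above.

import torch

net = SqueezeNetLSTM(n_frames=2, n_steps=10)
# Dummy batch with the 3 * 2 * n_frames = 12 channels the first conv expects.
x = torch.randn(1, 3 * 2 * 2, 128, 128)
stem_out = net.pre_metadata_features(x)
print(stem_out.shape)  # torch.Size([1, 24, 31, 31]) for this 128x128 example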