def __init__(self, n_steps=10, n_frames=2):
    """Build the SqueezeNet-style layers and initialize their weights.

    Args:
        n_steps: Controls the output width; the final 1x1 convolution emits
            ``n_steps * 2`` channels.
        n_frames: Number of input frames; the first convolution expects
            ``3 * 2 * n_frames`` input channels.
    """
    super(SqueezeNet, self).__init__()
    self.n_steps = n_steps
    self.n_frames = n_frames
    self.pre_metadata_features = nn.Sequential(
        # 3 * 2 per frame: presumably RGB x two images per frame
        # (e.g. a camera pair) -- TODO confirm against the data loader.
        nn.Conv2d(3 * 2 * self.n_frames, 16, kernel_size=3, stride=2),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        # Assumes Fire's output width is expand1x1 + expand3x3 (8 + 8 = 16).
        Fire(16, 4, 8, 8),
    )
    self.post_metadata_features = nn.Sequential(
        # Expects 24 input channels, 8 more than pre_metadata_features
        # produces. Presumably forward() concatenates 8 metadata channels
        # between the two stacks -- verify against forward().
        Fire(24, 6, 12, 12),
        nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        Fire(24, 8, 16, 16),
        Fire(32, 8, 16, 16),
        nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        Fire(32, 12, 24, 24),
        Fire(48, 12, 24, 24),
        Fire(48, 16, 32, 32),
        Fire(64, 16, 32, 32),
    )
    final_conv = nn.Conv2d(64, self.n_steps * 2, kernel_size=1)
    self.final_output = nn.Sequential(
        nn.Dropout(p=0.5),
        final_conv,
        # Deliberately no ReLU after final_conv (one was commented out in
        # earlier revisions), so outputs are not clamped to be non-negative.
        nn.AvgPool2d(kernel_size=5, stride=5),
    )

    # self.modules() recurses into every submodule, including the Fire
    # blocks, so any Conv2d inside them is initialized here as well.
    # The underscore (in-place) initializers replace the deprecated
    # init.normal / init.kaiming_uniform wrappers. They already run under
    # no_grad, so the parameters can be passed directly instead of `.data`.
    for m in self.modules():
        if not isinstance(m, nn.Conv2d):
            continue
        if m is final_conv:
            # Small-std normal init for the output layer only.
            init.normal_(m.weight, mean=0.0, std=0.01)
        else:
            init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
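# NOTE: stray web-page labels copied along with this code, not part of it:
# "评论列表" ("comment list"), "文章目录" ("table of contents").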