def __init__(self, n_steps=10, n_frames=2):
    """Build the convolutional layers and initialize their weights.

    Args:
        n_steps: Number of prediction steps. The final 1x1 conv emits
            ``n_steps * 2`` output channels (presumably two predicted values
            per step -- TODO confirm against forward()/the loss).
        n_frames: Number of input frames. The first conv expects
            ``3 * 2 * n_frames`` input channels (presumably 3 colour channels
            x 2 images per frame -- TODO confirm against the data loader).
    """
    super(Feedforward, self).__init__()
    self.n_steps = n_steps
    self.n_frames = n_frames

    # Stage 1: image features computed before any metadata is merged in.
    # Produces 8 output channels.
    self.pre_metadata_features = nn.Sequential(
        nn.Conv2d(3 * 2 * n_frames, 8, kernel_size=3, stride=2),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        nn.Conv2d(8, 8, kernel_size=3, padding=1),
    )

    # Stage 2 takes 16 input channels, 8 more than stage 1 produces;
    # presumably forward() concatenates 8 metadata channels onto stage 1's
    # output -- TODO confirm (forward() is not visible here).
    # NOTE(review): there is no nonlinearity between these convs, so each run
    # of consecutive convs composes to a single affine map. Inserting ReLUs
    # would shift the Sequential indices and break existing state_dict keys,
    # so the layer list is intentionally left unchanged.
    self.post_metadata_features = nn.Sequential(
        nn.Conv2d(16, 12, kernel_size=3, padding=1),
        nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        nn.Conv2d(12, 12, kernel_size=3, padding=1),
        nn.Conv2d(12, 16, kernel_size=3, padding=1),
        nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        nn.Conv2d(16, 16, kernel_size=3, padding=1),
        nn.Conv2d(16, 24, kernel_size=3, padding=1),
        nn.Conv2d(24, 24, kernel_size=3, padding=1),
    )

    # 1x1 conv projecting the 24 feature channels to n_steps * 2 outputs.
    # Kept in a local so the init loop below can single it out.
    final_conv = nn.Conv2d(24, self.n_steps * 2, kernel_size=1)
    # No activation follows final_conv, so the outputs are unbounded
    # (they can be negative).
    self.final_output = nn.Sequential(
        nn.Dropout(p=0.5),
        final_conv,
        # 5x5 window with stride 5: presumably the feature map is 5x5 at this
        # point, so this collapses it to 1x1 -- TODO confirm input resolution.
        nn.AvgPool2d(kernel_size=5, stride=5),
    )

    # Weight init: small normal for the output layer, Kaiming-uniform (default
    # arguments) for every other conv, zero biases. The in-place ``*_``
    # initializers replace the deprecated ``init.normal`` /
    # ``init.kaiming_uniform`` aliases and already run under torch.no_grad(),
    # so there is no need to go through ``.data``.
    for m in self.modules():
        if not isinstance(m, nn.Conv2d):
            continue
        if m is final_conv:
            init.normal_(m.weight, mean=0.0, std=0.01)
        else:
            init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
评论列表
文章目录