def __init__(self, input_example_non_batch, output_dim, reshape=None, dropout=0):
    """Set up the 3D convolutional observation embedding from an example input.

    Args:
        input_example_non_batch: Example observation tensor without a batch
            dimension. After the optional ``reshape`` it must be either 3d
            (depth x height x width) or 4d
            (num_channels x depth x height x width).
        output_dim: Embedding output size; stored as ``self.output_dim``.
        reshape: Optional shape (list, tuple, ``torch.Size`` or any sequence
            of ints) used to ``view`` the example before its rank is checked.
            The caller's object is never modified. ``self.reshape`` is a new
            list equal to ``[-1] + list(reshape)``, or ``None`` when no
            reshape is given.
        dropout: Dropout probability passed to ``nn.Dropout``.

    Raises:
        ValueError: If the (reshaped) example is neither 3d nor 4d.
    """
    super(ObserveEmbeddingCNN3D4C, self).__init__()
    if reshape is None:
        self.reshape = None
    else:
        target_shape = list(reshape)
        input_example_non_batch = input_example_non_batch.view(target_shape)
        # Build a fresh list rather than insert() into the caller's object:
        # mutating it in place would corrupt any later reuse of the same list
        # (yielding [-1, -1, ...]) and fails outright for tuples/torch.Size.
        # The leading -1 is for correct handling of the batch dimension in self.forward.
        self.reshape = [-1] + target_shape
    if input_example_non_batch.dim() == 3:
        # Add a leading channel dimension so the stored sample is 4d.
        self.input_sample = input_example_non_batch.unsqueeze(0).cpu()
    elif input_example_non_batch.dim() == 4:
        self.input_sample = input_example_non_batch.cpu()
    else:
        message = 'ObserveEmbeddingCNN3D4C: Expecting a 4d input_example_non_batch (num_channels x depth x height x width) or a 3d input_example_non_batch (depth x height x width). Received: {0}'.format(input_example_non_batch.size())
        util.logger.log(message)
        # Logging alone would leave self.input_sample unset, so the next line
        # would fail with an unrelated attribute error; fail clearly instead.
        raise ValueError(message)
    # First dimension of the (now 4d) sample is the channel count.
    self.input_channels = self.input_sample.size(0)
    self.output_dim = output_dim
    # Four 3D convolutions, kernel size 3: input_channels -> 64 -> 64 -> 128 -> 128.
    self.conv1 = nn.Conv3d(self.input_channels, 64, 3)
    self.conv2 = nn.Conv3d(64, 64, 3)
    self.conv3 = nn.Conv3d(64, 128, 3)
    self.conv4 = nn.Conv3d(128, 128, 3)
    self.drop = nn.Dropout(dropout)
评论列表
文章目录