nn.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:pyprob 作者: probprog 项目源码 文件源码
def __init__(self, input_example_non_batch, output_dim, reshape=None, dropout=0):
        """Set up a four-layer 3D-convolutional observation embedding.

        Args:
            input_example_non_batch: example observation tensor without a
                batch dimension. After the optional ``reshape`` it must be
                3d (depth x height x width), treated as a single channel, or
                4d (num_channels x depth x height x width).
            output_dim: target output dimension; stored as ``self.output_dim``.
            reshape: optional sequence of ints passed to ``Tensor.view`` to
                reshape the example. The caller's object is never modified;
                a copy with a leading ``-1`` is stored in ``self.reshape`` so
                the batch dimension is handled correctly in ``self.forward``.
            dropout: dropout probability passed to ``nn.Dropout``.

        Raises:
            ValueError: if the (possibly reshaped) example is neither 3d nor 4d.
        """
        super(ObserveEmbeddingCNN3D4C, self).__init__()
        if reshape is not None:
            # Work on a private copy: inserting -1 into the caller's list
            # corrupted a reshape list reused across constructions (one extra
            # -1 per instance) and made tuple arguments fail on .insert.
            reshape = list(reshape)
            input_example_non_batch = input_example_non_batch.view(reshape)
            # Leading -1 lets view() infer the batch dimension in self.forward.
            reshape = [-1] + reshape
        self.reshape = reshape

        num_dims = input_example_non_batch.dim()
        if num_dims == 3:
            # Single-channel volume: add a channel axis of size 1.
            self.input_sample = input_example_non_batch.unsqueeze(0).cpu()
        elif num_dims == 4:
            self.input_sample = input_example_non_batch.cpu()
        else:
            message = 'ObserveEmbeddingCNN3D4C: Expecting a 4d input_example_non_batch (num_channels x depth x height x width) or a 3d input_example_non_batch (depth x height x width). Received: {0}'.format(input_example_non_batch.size())
            util.logger.log(message)
            # Previously execution continued and crashed on the next line with
            # an AttributeError (self.input_sample unset); fail fast instead.
            raise ValueError(message)
        # NOTE(review): input_sample is kept on the CPU; presumably used later
        # to infer the convolution output size — confirm against the class.
        self.input_channels = self.input_sample.size(0)
        self.output_dim = output_dim
        # Four 3D convolutions, kernel_size=3, channels: in -> 64 -> 64 -> 128 -> 128.
        self.conv1 = nn.Conv3d(self.input_channels, 64, 3)
        self.conv2 = nn.Conv3d(64, 64, 3)
        self.conv3 = nn.Conv3d(64, 128, 3)
        self.conv4 = nn.Conv3d(128, 128, 3)
        self.drop = nn.Dropout(dropout)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号