conv.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:sesame-paste-noodle 作者: aissehust 项目源码 文件源码
def forward(self, inputtensor):
        inputimage = inputtensor[0]
        #print('conv2d.forward.type: {}'.format(inputimage.ndim))
        if self.dc == 0.0:
            pass
        else:
            if 0 <self.dc <=1:
                _srng = RandomStreams(np.random.randint(1, 2147462579))
                one = T.constant(1)
                retain_prob = one - self.dc
                mask_shape = self.w.shape
                mask = _srng.binomial(mask_shape, p=retain_prob,
                                           dtype=self.w.dtype)
                self.w = self.w * mask
            else:
                raise IndexError

        l3conv = T.nnet.conv2d(inputimage,
                               self.w,
                               border_mode=self.border,
                               subsample=self.subsample)
        if self.need_bias:            
            return ((l3conv+self.b.dimshuffle('x', 0, 'x', 'x')), )
        else:
            return (l3conv, )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号