def predict(self, new_data, batch_size):
    """Build this layer's output expression for a new batch of inputs.

    Re-applies the layer's convolution (filters ``self.W``), activation and
    max pooling (``self.poolsize``, ``ignore_border=True``) to ``new_data``,
    reusing the layer's parameters ``self.W`` and ``self.b``. The shape hint
    passed to ``conv.conv2d`` uses ``batch_size`` as its first entry instead
    of ``self.image_shape[0]``.

    The activation is selected by ``self.non_linear`` (exact string match):

    * ``"tanh"``    -- ``tanh(conv + b)``, then max-pool.
    * ``"relu"``    -- ``ReLU(conv + b)``, then max-pool.
    * anything else -- max-pool the raw convolution, then add ``b``
      (linear layer).

    :param new_data: input expression fed to ``conv.conv2d``; presumably a
        4-D tensor laid out like ``self.image_shape`` -- TODO confirm.
    :param batch_size: number of examples in ``new_data``; used as the first
        entry of the image-shape hint given to ``conv.conv2d``.
    :returns: the pooled (and, for tanh/relu, activated) output expression.
    """
    # Shape hint for conv2d: caller-supplied batch size, a hard-coded single
    # input channel, and the spatial size recorded in self.image_shape.
    img_shape = (batch_size, 1, self.image_shape[2], self.image_shape[3])
    conv_out = conv.conv2d(input=new_data, filters=self.W,
                           filter_shape=self.filter_shape,
                           image_shape=img_shape)
    # Broadcast the 1-D bias onto axis 1 so each filter gets its own offset.
    bias = self.b.dimshuffle('x', 0, 'x', 'x')

    # The three cases are mutually exclusive: `elif` is required, otherwise
    # the linear `else` branch would overwrite the tanh result.
    if self.non_linear == "tanh":
        activated = T.tanh(conv_out + bias)
        output = downsample.max_pool_2d(input=activated, ds=self.poolsize,
                                        ignore_border=True)
    elif self.non_linear == "relu":
        activated = ReLU(conv_out + bias)
        output = downsample.max_pool_2d(input=activated, ds=self.poolsize,
                                        ignore_border=True)
    else:
        # Linear case: pool first, then add the bias. A per-filter constant
        # commutes with max pooling, so the order does not change the result.
        pooled_out = downsample.max_pool_2d(input=conv_out, ds=self.poolsize,
                                            ignore_border=True)
        output = pooled_out + bias
    return output
conv_net_classes.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录