def pretrain_image(data_params, num_class=20, mode='train', learn_all=True):
    """Our version of Sketch-A-Net.

    data_params: dict with batch_size, source, shape, scale, rot
    mode: 'train', 'test' (validation net) or 'deploy'
    learn_all: if False, the pretrained layers are frozen (frozen_param)
    """
    param = learned_param if learn_all else frozen_param
    n = caffe.NetSpec()
    if mode == 'deploy':
        # the deploy net takes a single 1x1x225x225 input blob instead of a data layer
        n.data = L.Input(name='data', shape=dict(dim=[1, 1, 225, 225]))
    else:
        # custom Python data layer; data_params is serialized into param_str
        n.data, n.label = L.Python(module='data_layer', layer='DataLayer',
                                   ntop=2, phase=Train_Mode[mode],
                                   param_str=str(data_params))
    # convolutional trunk
    n.conv1_p, n.relu1_p = conv_relu(n.data, 15, 64, 3, param=param, name_prefix='conv1_p')
    n.pool1_p = pooling(n.relu1_p, 3, 2)
    n.conv2_p, n.relu2_p = conv_relu(n.pool1_p, 5, 128, param=param, name_prefix='conv2_p')
    n.pool2_p = pooling(n.relu2_p, 3, 2)
    n.conv3_p, n.relu3_p = conv_relu(n.pool2_p, 3, 256, param=param, name_prefix='conv3_p')
    n.conv4, n.relu4 = conv_relu(n.relu3_p, 3, 256, param=param, name_prefix='conv4')
    n.conv5, n.relu5 = conv_relu(n.relu4, 3, 256, param=param, name_prefix='conv5')
    n.pool5 = pooling(n.relu5, 3, 2)
    # fully connected layers; dropout (0.55) is applied only in the training net
    n.fc6, n.relu6 = fc_relu(n.pool5, 512, param=param, name_prefix='fc6')
    if mode == 'train':
        n.drop6 = fc7input = L.Dropout(n.relu6, dropout_ratio=0.55, in_place=True)
    else:
        fc7input = n.relu6
    n.fc7, n.relu7 = fc_relu(fc7input, 512, param=param, name_prefix='fc7')
    if mode == 'train':
        n.drop7 = fc8input = L.Dropout(n.relu7, dropout_ratio=0.55, in_place=True)
    else:
        fc8input = n.relu7
    # the final classifier always uses learned_param, so it is trained even
    # when the pretrained layers are frozen (learn_all=False)
    #n.feat8_r = fullconnect(fc8input, 100, param=learned_param, name_prefix='fc8_r')
    n.feat8 = fullconnect(fc8input, num_class, param=learned_param, name_prefix='fc8')
    if mode != 'deploy':
        n.loss = L.SoftmaxWithLoss(n.feat8, n.label)
    if mode == 'test':  # validation
        n.accuracy = L.Accuracy(n.feat8, n.label, phase=Train_Mode[mode])
    proto = n.to_proto()
    proto.name = 'SketchANet'
    return proto
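
For reference, here is a minimal usage sketch (not part of the original code): it assumes the helper functions and the data_layer module above are importable, and the data_params values and the output file name are illustrative only.

# Hypothetical usage: build the training net and dump it as a prototxt file.
# The data_params values and the file name below are illustrative assumptions.
if __name__ == '__main__':
    data_params = {'batch_size': 64, 'source': 'train_list.txt',
                   'shape': (225, 225), 'scale': 1.0, 'rot': True}
    train_proto = pretrain_image(data_params, num_class=20, mode='train')
    with open('sketchanet_train.prototxt', 'w') as f:
        f.write(str(train_proto))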