def pretrain_sketch(data_params, num_class=20, mode='train', learn_all=True,
                    dropout_ratio=0.55, deploy_shape=(1, 1, 225, 225)):
    """Build our version of Sketch-A-Net as a Caffe ``NetParameter``.

    Args:
        data_params: configuration for the Python ``DataLayer``; it is
            forwarded as ``param_str=str(data_params)``. Expected to carry
            batch_size, source, shape, scale, rot (as read by that layer --
            not validated here). Unused when ``mode == 'deploy'``.
        num_class: number of outputs of the final ``fc8_s`` classifier.
        mode: ``'train'``, ``'test'`` (validation) or ``'deploy'``.
            ``'train'`` inserts in-place dropout after fc6 and fc7;
            ``'test'`` additionally adds an ``Accuracy`` layer;
            ``'deploy'`` uses a fixed-size ``Input`` blob instead of the data
            layer and emits no loss/accuracy layers. For non-deploy modes the
            value is also looked up in ``Train_Mode``.
        learn_all: if True, the conv/fc6/fc7 layers use ``learned_param``,
            otherwise ``frozen_param``. ``fc8_s`` always uses
            ``learned_param``.
        dropout_ratio: dropout ratio applied after fc6 and fc7 in train mode
            (default 0.55, as before).
        deploy_shape: dimensions of the ``data`` Input blob in deploy mode
            (default ``(1, 1, 225, 225)``, as before).

    Returns:
        The generated ``NetParameter`` proto, with ``name = 'SketchANet'``.
    """
    param = learned_param if learn_all else frozen_param
    n = caffe.NetSpec()

    # Input: a fixed-size blob for deployment, otherwise the custom Python
    # data layer producing (data, label).
    if mode == 'deploy':
        # NetSpec only extends repeated fields from a real list, so convert.
        n.data = L.Input(name='data', shape=dict(dim=list(deploy_shape)))
    else:
        # NOTE(review): Train_Mode presumably maps mode -> caffe phase; confirm.
        n.data, n.label = L.Python(module='data_layer', layer='DataLayer',
                                   ntop=2, phase=Train_Mode[mode],
                                   param_str=str(data_params))

    # Convolutional trunk. conv1 passes an extra positional argument (3),
    # presumably its stride -- confirm against conv_relu's signature.
    n.conv1_a, n.relu1_a = conv_relu(n.data, 15, 64, 3, param=param,
                                     name_prefix='conv1_a')
    n.pool1_a = pooling(n.relu1_a, 3, 2)
    n.conv2_a, n.relu2_a = conv_relu(n.pool1_a, 5, 128, param=param,
                                     name_prefix='conv2_a')
    n.pool2_a = pooling(n.relu2_a, 3, 2)
    n.conv3_a, n.relu3_a = conv_relu(n.pool2_a, 3, 256, param=param,
                                     name_prefix='conv3_a')
    n.conv4_s, n.relu4_s = conv_relu(n.relu3_a, 3, 256, param=param,
                                     name_prefix='conv4_s')
    n.conv5_s, n.relu5_s = conv_relu(n.relu4_s, 3, 256, param=param,
                                     name_prefix='conv5_s')
    n.pool5_s = pooling(n.relu5_s, 3, 2)

    # Fully connected head; in-place dropout is only added for training.
    n.fc6_s, n.relu6_s = fc_relu(n.pool5_s, 512, param=param,
                                 name_prefix='fc6_s')
    if mode == 'train':
        n.drop6_s = fc7input = L.Dropout(
            n.relu6_s, dropout_ratio=dropout_ratio, in_place=True)
    else:
        fc7input = n.relu6_s
    n.fc7_s, n.relu7_s = fc_relu(fc7input, 512, param=param,
                                 name_prefix='fc7_s')
    if mode == 'train':
        n.drop7_s = fc8input = L.Dropout(
            n.relu7_s, dropout_ratio=dropout_ratio, in_place=True)
    else:
        fc8input = n.relu7_s

    # The classifier always uses learned_param, regardless of learn_all
    # (presumably so it can be trained on top of a frozen trunk).
    n.feat8_s = fullconnect(fc8input, num_class, param=learned_param,
                            name_prefix='fc8_s')

    if mode != 'deploy':
        n.loss = L.SoftmaxWithLoss(n.feat8_s, n.label)
    if mode == 'test':  # validation
        n.accuracy = L.Accuracy(n.feat8_s, n.label, phase=Train_Mode[mode])

    proto = n.to_proto()
    proto.name = 'SketchANet'
    return proto
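# "评论列表" (Comment list) / "文章目录" (Table of contents): navigation text
# scraped from the web page this code was copied from; not part of the program.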