def SketchANet(data_params, num_class = 20, val_mode = 0):
""" our version of Sketch-A-Net
data_params: batch_size, source, shape, scale, rot
val_mode: 0 if this is train net, 1 if test net, 2 if deploy net
"""
n = caffe.NetSpec()
if val_mode == 2:
n.data = L.Input(name='data',
shape=dict(dim=[1,1,225,225]))
else:
n.data, n.label = L.Python(module = 'data_layer', layer = 'DataLayer',
ntop = 2, phase = val_mode,
param_str = str(data_params))
#==============================================================================
# n.data, n.label = L.Data(batch_size=batch_size, backend=P.Data.LMDB, source=lmdb,
# transform_param=dict(scale=1./255), ntop=2)
#==============================================================================
n.conv1, n.relu1 = conv_relu(n.data, 15, 64, stride = 3)
n.pool1 = pooling(n.relu1,3, stride = 2)
n.conv2, n.relu2 = conv_relu(n.pool1, 5, 128)
n.pool2 = pooling(n.relu2,3, stride = 2)
n.conv3, n.relu3 = conv_relu(n.pool2, 3, 256, pad = 1)
n.conv4, n.relu4 = conv_relu(n.relu3, 3, 256, 1, 1)
n.conv5, n.relu5 = conv_relu(n.relu4, 3, 256, pad=1)
n.pool5 = pooling(n.relu5,3, stride = 2)
n.fc6, n.relu6 = fc_relu(n.pool5, 512)
if val_mode != 2:
n.drop6 = L.Dropout(n.relu6, dropout_ratio = 0.55, in_place = True)
n.fc7, n.relu7 = fc_relu(n.drop6, 512)
n.drop7 = L.Dropout(n.relu7, dropout_ratio = 0.55, in_place = True)
n.fc8 = fullconnect(n.drop7, num_class)
n.loss = L.SoftmaxWithLoss(n.fc8, n.label)
else: #deploy mode
n.fc7, n.relu7 = fc_relu(n.relu6, 512)
n.fc8 = fullconnect(n.relu7, num_class)
if val_mode==1:
n.accuracy = L.Accuracy(n.fc8, n.label, phase = val_mode)
proto = n.to_proto()
proto.name = 'SketchANet'
return proto
评论列表
文章目录