model_def2.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Triplet_Loss_SBIR 作者: TuBui 项目源码 文件源码
def SketchANet(data_params, num_class = 20, val_mode = 0):
  """ our version of Sketch-A-Net
  data_params: batch_size, source, shape, scale, rot
  val_mode: 0 if this is train net, 1 if test net, 2 if deploy net
  """
  n = caffe.NetSpec()
  if val_mode == 2:
    n.data        = L.Input(name='data',
                           shape=dict(dim=[1,1,225,225]))
  else:
    n.data, n.label = L.Python(module = 'data_layer', layer = 'DataLayer',
                               ntop = 2, phase = val_mode,
                               param_str = str(data_params))
#==============================================================================
#     n.data, n.label = L.Data(batch_size=batch_size, backend=P.Data.LMDB, source=lmdb,
#                              transform_param=dict(scale=1./255), ntop=2)
#==============================================================================

  n.conv1, n.relu1 = conv_relu(n.data, 15, 64, stride = 3)
  n.pool1 = pooling(n.relu1,3, stride = 2)

  n.conv2, n.relu2 = conv_relu(n.pool1, 5, 128)
  n.pool2 = pooling(n.relu2,3, stride = 2)

  n.conv3, n.relu3 = conv_relu(n.pool2, 3, 256, pad = 1)
  n.conv4, n.relu4 = conv_relu(n.relu3, 3, 256, 1, 1)

  n.conv5, n.relu5 = conv_relu(n.relu4, 3, 256, pad=1)
  n.pool5 = pooling(n.relu5,3, stride = 2)

  n.fc6, n.relu6 = fc_relu(n.pool5, 512)
  if val_mode != 2:
    n.drop6 = L.Dropout(n.relu6, dropout_ratio = 0.55, in_place = True)

    n.fc7, n.relu7 = fc_relu(n.drop6, 512)
    n.drop7 = L.Dropout(n.relu7, dropout_ratio = 0.55, in_place = True)

    n.fc8 = fullconnect(n.drop7, num_class)
    n.loss = L.SoftmaxWithLoss(n.fc8, n.label)
  else: #deploy mode
    n.fc7, n.relu7 = fc_relu(n.relu6, 512)
    n.fc8 = fullconnect(n.relu7, num_class)

  if val_mode==1:
    n.accuracy = L.Accuracy(n.fc8, n.label, phase = val_mode)

  proto = n.to_proto()
  proto.name = 'SketchANet'
  return proto
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号