model_def2.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:Triplet_Loss_SBIR 作者: TuBui 项目源码 文件源码
def pretrain_image(data_params, num_class=20, mode='train', learn_all=True):
  """Build our version of Sketch-A-Net as a single-branch classifier.

  Args:
    data_params: parameters for the Python ``DataLayer`` (e.g. batch_size,
      source, shape, scale, rot); passed through as ``str(data_params)`` in
      ``param_str``. Unused when ``mode == 'deploy'``.
    num_class: number of outputs of the final ``fc8`` layer.
    mode: one of 'train', 'test' (validation) or 'deploy'.
      - 'deploy': the network input is an ``Input`` blob of shape
        [1, 1, 225, 225] and no loss/accuracy layers are added.
      - otherwise: data and labels come from ``data_layer.DataLayer`` with
        phase ``Train_Mode[mode]``, and a ``SoftmaxWithLoss`` is added.
      - Dropout after fc6/fc7 is inserted only for 'train'.
      - An ``Accuracy`` layer is added only for 'test'.
    learn_all: if True, conv1_p..fc7 use ``learned_param``; otherwise they
      use ``frozen_param``. ``fc8`` always uses ``learned_param``.

  Returns:
    The proto produced by ``NetSpec.to_proto()``, with its name set to
    'SketchANet'.
  """
  param = learned_param if learn_all else frozen_param
  n = caffe.NetSpec()
  if mode == 'deploy':
    n.data = L.Input(name='data', shape=dict(dim=[1, 1, 225, 225]))
  else:
    n.data, n.label = L.Python(module='data_layer', layer='DataLayer',
                               ntop=2, phase=Train_Mode[mode],
                               param_str=str(data_params))

  n.conv1_p, n.relu1_p = conv_relu(n.data, 15, 64, 3, param=param,
                                   name_prefix='conv1_p')
  n.pool1_p = pooling(n.relu1_p, 3, 2)

  n.conv2_p, n.relu2_p = conv_relu(n.pool1_p, 5, 128, param=param,
                                   name_prefix='conv2_p')
  n.pool2_p = pooling(n.relu2_p, 3, 2)

  n.conv3_p, n.relu3_p = conv_relu(n.pool2_p, 3, 256, param=param,
                                   name_prefix='conv3_p')
  n.conv4, n.relu4 = conv_relu(n.relu3_p, 3, 256, param=param,
                               name_prefix='conv4')

  n.conv5, n.relu5 = conv_relu(n.relu4, 3, 256, param=param,
                               name_prefix='conv5')
  n.pool5 = pooling(n.relu5, 3, 2)

  n.fc6, n.relu6 = fc_relu(n.pool5, 512, param=param, name_prefix='fc6')

  # Single source for the dropout ratio used after both fc6 and fc7.
  drop_ratio = 0.55
  # Dropout is only inserted when training; in the other modes fc7/fc8
  # consume the ReLU outputs directly.
  if mode == 'train':
    n.drop6 = fc7input = L.Dropout(n.relu6, dropout_ratio=drop_ratio,
                                   in_place=True)
  else:
    fc7input = n.relu6
  n.fc7, n.relu7 = fc_relu(fc7input, 512, param=param, name_prefix='fc7')
  if mode == 'train':
    n.drop7 = fc8input = L.Dropout(n.relu7, dropout_ratio=drop_ratio,
                                   in_place=True)
  else:
    fc8input = n.relu7

  # fc8 deliberately ignores learn_all and always uses learned_param, so the
  # classifier head is trainable even when frozen_param is selected for the
  # layers below it.
  n.feat8 = fullconnect(fc8input, num_class, param=learned_param,
                        name_prefix='fc8')

  if mode != 'deploy':
    n.loss = L.SoftmaxWithLoss(n.feat8, n.label)

  if mode == 'test':  # validation
    n.accuracy = L.Accuracy(n.feat8, n.label, phase=Train_Mode[mode])

  proto = n.to_proto()
  proto.name = 'SketchANet'
  return proto
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号