rot_mnist12K_model.py source file

python

Project: TI-pooling  Author: dlaptev
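import tensorflow as tf  # written against the TensorFlow 1.x graph API


def define_model(x,
                 keep_prob,
                 number_of_classes,
                 number_of_filters,
                 number_of_fc_features):
  # single_branch() and fc() are helpers that are not shown in this listing.
  # Split the input along axis 4, giving one tensor per slice of that axis.
  # tf.unstack / tf.stack are the TF >= 1.0 names for tf.unpack / tf.pack.
  splitted = tf.unstack(x, axis=4)
  branches = []
  with tf.variable_scope('branches') as scope:
    for index, tensor_slice in enumerate(splitted):
      branches.append(single_branch(tensor_slice,
                                    number_of_filters,
                                    number_of_fc_features))
      # Share weights: every branch after the first reuses the same variables.
      if index == 0:
        scope.reuse_variables()
    # Stack the branch outputs on a new axis 2, then take the element-wise
    # maximum over that axis (the TI-pooling step).
    concatenated = tf.stack(branches, axis=2)
    ti_pooled = tf.reduce_max(concatenated, axis=2)
    drop = tf.nn.dropout(ti_pooled, keep_prob)
  with tf.variable_scope('fc2'):
    # Final fully connected layer: pooled features -> class logits.
    logits = fc(drop,
                [number_of_fc_features, number_of_classes],
                [number_of_classes])
  return logits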
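
The pooling step in `define_model` (collecting the per-branch outputs and taking `tf.reduce_max` over the branch axis) can be illustrated without TensorFlow. The following is only a minimal pure-Python sketch for a single sample; the helper name `ti_pool` and the feature values are hypothetical and are not part of the TI-pooling project.

```python
def ti_pool(branch_features):
    """Element-wise maximum over per-branch feature vectors.

    branch_features: list of equal-length lists, one list per branch
    (for example, one per transformed copy of the same input).
    """
    if not branch_features:
        raise ValueError("need at least one branch")
    length = len(branch_features[0])
    if any(len(f) != length for f in branch_features):
        raise ValueError("all branches must have the same feature size")
    # zip(*...) groups the i-th feature of every branch together,
    # which mirrors reducing over the stacked branch axis.
    return [max(values) for values in zip(*branch_features)]


# Hypothetical outputs of three weight-sharing branches for one sample.
features = [
    [0.1, 0.9, 0.3],
    [0.4, 0.2, 0.8],
    [0.7, 0.5, 0.6],
]
pooled = ti_pool(features)
print(pooled)  # [0.7, 0.9, 0.8]
```

Because the maximum does not depend on the order of the branches, reordering the transformed copies leaves the pooled features unchanged.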