model.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:magenta 作者: tensorflow 项目源码 文件源码
def conv_net_kelz(inputs):
  """Constructs the convolutional stack from Kelz 2016.

  The stack is three 3x3 convolutions (32, 32 and 64 filters, ReLU
  activations, batch norm on the second one only), with a [1, 2] max pool
  and dropout after the second and third convolutions. The last two axes of
  the result are then merged and passed through a 512-unit fully connected
  layer followed by dropout.

  Args:
    inputs: Tensor fed to the first convolution. Axes 2 and 3 of the pooled
      feature map must have statically known sizes, because they are
      multiplied together in Python to size the flattened axis.
      Presumably laid out as (batch, time, frequency, channels) -- TODO
      confirm against the caller.

  Returns:
    The output tensor of the final dropout layer.
  """
  # Every conv / fully connected layer below shares this initializer and the
  # ReLU activation through the arg_scope.
  weight_init = tf.contrib.layers.variance_scaling_initializer(
      factor=2.0, mode='FAN_AVG', uniform=True)

  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      activation_fn=tf.nn.relu,
      weights_initializer=weight_init):
    # Block 1: plain convolution.
    conv1 = slim.conv2d(
        inputs, num_outputs=32, kernel_size=[3, 3], scope='conv1')

    # Block 2: batch-normalized convolution, then pool along axis 2 only.
    # NOTE(review): slim.dropout's second positional argument is keep_prob,
    # so 0.25 keeps a quarter of the activations -- confirm this is intended.
    conv2 = slim.conv2d(
        conv1,
        num_outputs=32,
        kernel_size=[3, 3],
        normalizer_fn=slim.batch_norm,
        scope='conv2')
    pool2 = slim.max_pool2d(
        conv2, kernel_size=[1, 2], stride=[1, 2], scope='pool2')
    drop2 = slim.dropout(pool2, keep_prob=0.25, scope='dropout2')

    # Block 3: wider convolution with the same pool / dropout pattern.
    conv3 = slim.conv2d(
        drop2, num_outputs=64, kernel_size=[3, 3], scope='conv3')
    pool3 = slim.max_pool2d(
        conv3, kernel_size=[1, 2], stride=[1, 2], scope='pool3')
    drop3 = slim.dropout(pool3, keep_prob=0.25, scope='dropout3')

    # Merge the two trailing axes while keeping the batch and time axes.
    # The leading sizes are read dynamically from the graph; the trailing
    # sizes come from the static shape.
    runtime_shape = tf.shape(drop3)
    batch_size = runtime_shape[0]
    num_steps = runtime_shape[1]
    flat_width = drop3.shape[2].value * drop3.shape[3].value
    flat4 = tf.reshape(
        drop3, (batch_size, num_steps, flat_width), name='flatten4')

    # Dense projection followed by the final dropout.
    fc5 = slim.fully_connected(flat4, num_outputs=512, scope='fc5')
    return slim.dropout(fc5, keep_prob=0.5, scope='dropout5')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号