layers.py 文件源码

python
阅读 15 收藏 0 点赞 0 评论 0

项目:vessel-classification 作者: GlobalFishingWatch 项目源码 文件源码
def misconception_fishing(input,
                          window_size,
                          depths,
                          strides,
                          objective_function,
                          is_training,
                          pre_count=128,
                          post_count=128,
                          post_layers=1,
                          keep_prob=0.5,
                          internal_keep_prob=0.5,
                          other_objectives=()):
    """Build a per-position fishing head on top of the misconception model.

    Runs ``misconception_model`` to obtain its list of intermediate layers.
    Each layer gets a 1x1 conv + batch norm + ReLU projection, and layer
    ``i`` is expanded with ``utility.repeat_tensor(..., 2**i)``. The expanded
    layers are summed into one embedding. That embedding passes through the
    post-summation 1x1 layers and dropout, then a final single-channel 1x1
    convolution. The result is squeezed and handed to
    ``objective_function.build``.

    Args:
        input: Input tensor forwarded to ``misconception_model``.
        window_size: Forwarded to ``misconception_model``.
        depths: Forwarded to ``misconception_model``.
        strides: Forwarded to ``misconception_model``.
        objective_function: Object whose ``build`` method receives the
            squeezed fishing logits; its result is returned.
        is_training: Forwarded to ``misconception_model``. Also switches
            batch norm and dropout between training and inference mode.
        pre_count: Output channels of the 1x1 projection applied to each
            intermediate layer before summation.
        post_count: Channel count of the post-summation 1x1 layers. Also
            passed to ``misconception_model`` as ``sub_count``.
        post_layers: Number of post-summation 1x1 layers. The first
            ``post_layers - 1`` use batch norm; one final layer without
            batch norm is always built, even if ``post_layers < 1``.
        keep_prob: Keep probability for the dropout applied to the final
            embedding.
        internal_keep_prob: Not used by this function; kept only for
            interface compatibility. NOTE(review): confirm whether it was
            meant to be forwarded to ``misconception_model``.
        other_objectives: Forwarded positionally to ``misconception_model``.

    Returns:
        Whatever ``objective_function.build`` returns for the fishing logits.
    """

    def conv_bn_relu(tensor, channels):
        # 1x1 conv -> batch norm -> ReLU. Shared by the per-layer projection
        # and the hidden post-summation layers. Variables are created in the
        # same order as before, so slim's auto-generated scope names (and
        # therefore existing checkpoints) are unaffected.
        return slim.conv2d(
            tensor,
            channels, [1, 1],
            activation_fn=tf.nn.relu,
            normalizer_fn=slim.batch_norm,
            normalizer_params={'is_training': is_training})

    _, layers = misconception_model(
        input,
        window_size,
        depths,
        strides,
        other_objectives,
        is_training,
        sub_count=post_count,
        sub_layers=2)

    # Project every intermediate layer to pre_count channels, then expand
    # layer i by a factor of 2**i. NOTE(review): presumably this undoes
    # per-layer stride-2 downsampling so all layers share a length for
    # add_n. Confirm against utility.repeat_tensor and the strides used.
    expanded_layers = [
        utility.repeat_tensor(conv_bn_relu(lyr, pre_count), 2**i)
        for i, lyr in enumerate(layers)
    ]

    # tf.add_n requires every expanded layer to have the same shape.
    embedding = tf.add_n(expanded_layers)

    for _ in range(post_layers - 1):
        embedding = conv_bn_relu(embedding, post_count)
    # Final post layer: ReLU activation but deliberately no batch norm.
    embedding = slim.conv2d(
        embedding,
        post_count, [1, 1],
        activation_fn=tf.nn.relu,
        normalizer_fn=None)
    embedding = slim.dropout(embedding, keep_prob, is_training=is_training)

    # Single output channel per position, with no activation (raw logits).
    logits = slim.conv2d(
        embedding, 1, [1, 1], activation_fn=None, normalizer_fn=None)
    # Drop axes 1 and 3 of the rank-4 conv output to get a rank-2 tensor.
    # NOTE(review): assumes both axes have size 1 (height and channel).
    # `axis` replaces the deprecated `squeeze_dims` keyword.
    fishing_outputs = tf.squeeze(logits, axis=[1, 3])

    return objective_function.build(fishing_outputs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号