frame_level_models.py source code

python

Project: youtube-8m  Author: wangheda
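import tensorflow as tf  # TensorFlow 1.x API (tf.contrib was removed in 2.x)


def conv_block(self, input, out_size, layer, kernalsize=3, l2_penalty=1e-8, shortcut=False):
    """Convolution + bias + batch norm + ReLU block (a method of a model class)."""
    in_shape = input.get_shape().as_list()  # [batch, height, width, channels]
    if layer > 0:
        # Later layers: kernalsize x 1 filter over all input channels.
        filter_shape = [kernalsize, 1, in_shape[3], out_size]
    else:
        # First layer: the filter spans the full input width, one input channel.
        filter_shape = [kernalsize, in_shape[2], 1, out_size]
    W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W-%s" % layer)
    b = tf.Variable(tf.constant(0.1, shape=[out_size]), name="b-%s" % layer)
    # Register L2 penalties on the weights and the bias.
    tf.add_to_collection(name=tf.GraphKeys.REGULARIZATION_LOSSES, value=l2_penalty*tf.nn.l2_loss(W))
    tf.add_to_collection(name=tf.GraphKeys.REGULARIZATION_LOSSES, value=l2_penalty*tf.nn.l2_loss(b))
    if layer > 0:
        # SAME padding keeps the spatial shape unchanged.
        conv = tf.nn.conv2d(input, W, strides=[1, 1, 1, 1], padding="SAME", name="conv-%s" % layer)
    else:
        # VALID padding collapses the width dimension to 1.
        conv = tf.nn.conv2d(input, W, strides=[1, 1, 1, 1], padding="VALID", name="conv-%s" % layer)
    if shortcut:
        # 1x1 convolution projecting the input to out_size channels, added to conv.
        # The shortcut uses SAME padding, so it appears intended for layer > 0,
        # where the main branch also preserves the spatial shape.
        shortshape = [1, 1, in_shape[3], out_size]
        Ws = tf.Variable(tf.truncated_normal(shortshape, stddev=0.05), name="Ws-%s" % layer)
        tf.add_to_collection(name=tf.GraphKeys.REGULARIZATION_LOSSES, value=l2_penalty*tf.nn.l2_loss(Ws))
        conv = conv + tf.nn.conv2d(input, Ws, strides=[1, 1, 1, 1], padding="SAME", name="conv-shortcut-%s" % layer)
    h = tf.nn.bias_add(conv, b)
    # batch_norm is called with its default is_training setting.
    h2 = tf.contrib.layers.batch_norm(h, center=True, scale=True, epsilon=1e-5, decay=0.9)
    h2 = tf.nn.relu(h2, name="relu-%s" % layer)

    return h2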
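
The filter and padding choices in conv_block determine the output shape of each layer. The following is a minimal sketch that mirrors only that shape arithmetic in plain Python, without TensorFlow. The helper names conv_output_shape and conv_block_shape, and the example input dimensions, are hypothetical and not part of the original file.

```python
def conv_output_shape(in_shape, filter_shape, padding):
    """Return the NHWC output shape of a stride-1 2-D convolution.

    in_shape:     [batch, height, width, in_channels]
    filter_shape: [filter_height, filter_width, in_channels, out_channels]
    """
    batch, height, width, in_channels = in_shape
    f_h, f_w, f_in, f_out = filter_shape
    if f_in != in_channels:
        raise ValueError("filter in_channels does not match input")
    if padding == "SAME":
        # Stride 1 with SAME padding preserves height and width.
        return [batch, height, width, f_out]
    if padding == "VALID":
        # No padding: each spatial dimension shrinks by (filter size - 1).
        return [batch, height - f_h + 1, width - f_w + 1, f_out]
    raise ValueError("unknown padding: %r" % padding)


def conv_block_shape(in_shape, out_size, layer, kernalsize=3):
    """Mirror the filter/padding choice made in conv_block (shapes only)."""
    if layer > 0:
        filter_shape = [kernalsize, 1, in_shape[3], out_size]
        padding = "SAME"
    else:
        filter_shape = [kernalsize, in_shape[2], 1, out_size]
        padding = "VALID"
    return conv_output_shape(in_shape, filter_shape, padding)


# Assumed example input: batch of 2, 300 frames, 1152 features, 1 channel.
first = conv_block_shape([2, 300, 1152, 1], out_size=64, layer=0)
second = conv_block_shape(first, out_size=128, layer=1)
print(first)   # [2, 298, 1, 64]
print(second)  # [2, 298, 1, 128]
```

Under these assumptions the first layer collapses the feature axis to width 1 and moves the information into channels, and later layers convolve only along the frame axis while keeping the shape fixed.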