ops.py source code

Project: DeepVideo    Author: AniketBajpai
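```python
import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.contrib)


def conv3d(input_, output_shape, is_train,
           k=4, s=2, stddev=0.01,
           name='conv3d', with_w=False):
    """3-D convolution -> bias -> batch norm -> leaky ReLU.

    Only output_shape[-1] (the number of output channels) is used.
    """
    # Same kernel size and stride along depth, height and width
    k_d = k_h = k_w = k
    s_d = s_h = s_w = s
    with tf.variable_scope(name):
        # Filter layout: [k_d, k_h, k_w, in_channels, out_channels]
        weights = tf.get_variable('weights', [k_d, k_h, k_w, input_.get_shape()[-1], output_shape[-1]],
                                  initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv3d(input_, weights, strides=[1, s_d, s_h, s_w, 1], padding='SAME')

        # Largely redundant ahead of batch norm (center=True adds its own offset)
        biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        # bias_add preserves the shape, so no reshape is needed; the original
        # tf.reshape(..., conv.get_shape()) fails if the batch size is unknown
        conv = tf.nn.bias_add(conv, biases)
        # updates_collections=None: moving mean/variance are updated in place
        bn = tf.contrib.layers.batch_norm(conv, center=True, scale=True, decay=0.9,
                                          is_training=is_train, updates_collections=None)
        # lrelu is a leaky-ReLU helper defined elsewhere in ops.py (not shown)
        out = lrelu(bn, name='lrelu')

        if with_w:
            return out, weights, biases
        else:
            return out
```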
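The block uses padding='SAME' with the same stride s in every spatial dimension. Each of depth, height and width therefore shrinks to ceil(n / s), whatever the kernel size k. The sketch below is plain Python with no TensorFlow dependency. It computes that output shape and counts the trainable variables one call creates: weights, biases, and the batch-norm beta/gamma. The helper names conv3d_same_output_shape and conv3d_trainable_params are hypothetical. The sketch also assumes NDHWC input (batch, depth, height, width, channels), which is the default layout for tf.nn.conv3d.

```python
import math


def conv3d_same_output_shape(input_shape, out_channels, s=2):
    """Output shape of conv3d(...) above (hypothetical helper).

    input_shape is NDHWC: (batch, depth, height, width, in_channels).
    With padding='SAME', each spatial dim becomes ceil(dim / s); k does not matter.
    """
    n, d, h, w, _ = input_shape
    return (n, math.ceil(d / s), math.ceil(h / s), math.ceil(w / s), out_channels)


def conv3d_trainable_params(in_channels, out_channels, k=4):
    """Trainable parameters created by one conv3d(...) call (hypothetical helper)."""
    weights = k * k * k * in_channels * out_channels  # [k_d, k_h, k_w, C_in, C_out]
    biases = out_channels                              # the 'biases' variable
    bn_beta_gamma = 2 * out_channels                   # center=True, scale=True
    # Moving mean/variance are batch-norm variables too, but they are not trainable
    return weights + biases + bn_beta_gamma
```

For example, a batch of 8 clips of shape 16x64x64x3 mapped to 64 channels gives conv3d_same_output_shape((8, 16, 64, 64, 3), 64) == (8, 8, 32, 32, 64). With the default k=4, conv3d_trainable_params(3, 64) == 12480.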