positional_cnn_deep_combine_chain_model.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:youtube-8m 作者: wangheda 项目源码 文件源码
def add_positional_embedding(self, model_input, num_frames, l2_penalty=1e-8):
    batch_size, max_frames, num_features = model_input.get_shape().as_list()
    positional_embedding = tf.get_variable("positional_embedding", dtype=tf.float32,
                                shape=[1, max_frames, num_features], 
                                initializer=tf.zeros_initializer(),
                                regularizer=tf.contrib.layers.l2_regularizer(l2_penalty))
    mask = tf.sequence_mask(lengths=num_frames, maxlen=max_frames, dtype=tf.float32)
    model_input_with_positional_embedding = tf.einsum("ijk,ij->ijk", model_input + positional_embedding, mask)
    return model_input_with_positional_embedding
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号