lstm_positional_attention_max_pooling_model.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:youtube-8m 作者: wangheda 项目源码 文件源码
def get_mean_input(self, model_input, num_frames):
    batch_size, max_frames, num_features = model_input.get_shape().as_list()
    mask = tf.sequence_mask(lengths=num_frames, maxlen=max_frames, dtype=tf.float32)
    mean_input = tf.einsum("ijk,ij->ik", model_input, mask) / tf.expand_dims(tf.cast(num_frames, dtype=tf.float32), dim=1)
    tiled_mean_input = tf.tile(tf.expand_dims(mean_input, dim=1), multiples=[1,max_frames,1])
    return tiled_mean_input
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号