mm_lstm_memory_model.py source code

Project: youtube-8m  Author: wangheda
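# Module-level dependencies of this method (TensorFlow 1.x API).
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow import flags

FLAGS = flags.FLAGS  # FLAGS.mm_label_embedding is defined elsewhere in the project.

def matching_matrix(self,
                    model_input,
                    vocab_size,
                    l2_penalty=1e-8,
                    **unused_params):
    """Scores every frame against every label.

    model_input: [batch_size, max_frames, num_features] frame features.
    Returns a [batch_size, max_frames, vocab_size] tensor of frame-label scores.
    """
    max_frames = model_input.get_shape().as_list()[1]
    num_features = model_input.get_shape().as_list()[2]
    embedding_size = FLAGS.mm_label_embedding

    # Flatten so that each frame is one row: [batch_size * max_frames, num_features].
    model_input = tf.reshape(model_input, [-1, num_features])

    # Two bias-free, L2-regularized projections into the label-embedding space.
    frame_relu = slim.fully_connected(
        model_input,
        embedding_size,
        activation_fn=tf.nn.relu,
        biases_initializer=None,
        weights_regularizer=slim.l2_regularizer(l2_penalty),
        scope="mm_relu")

    frame_activation = slim.fully_connected(
        frame_relu,
        embedding_size,
        activation_fn=tf.nn.tanh,
        biases_initializer=None,
        weights_regularizer=slim.l2_regularizer(l2_penalty),
        scope="mm_activation")

    # One trainable embedding vector per label: [vocab_size, embedding_size].
    label_embedding = tf.get_variable(
        "label_embedding",
        shape=[vocab_size, embedding_size],
        dtype=tf.float32,
        initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.5),
        regularizer=slim.l2_regularizer(l2_penalty),
        trainable=True)

    # Dot product of every frame embedding with every label embedding:
    # [batch_size * max_frames, vocab_size].
    mm_matrix = tf.einsum("ik,jk->ij", frame_activation, label_embedding)
    # Restore the frame axis: [batch_size, max_frames, vocab_size].
    mm_output = tf.reshape(mm_matrix, [-1, max_frames, vocab_size])
    return mm_output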
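To make the shape bookkeeping of the method concrete, here is a minimal NumPy sketch of the same forward pass. It assumes that a bias-free `slim.fully_connected` layer reduces to a matrix multiply followed by its activation. The function name `matching_matrix_np` and the random weights are hypothetical stand-ins for the trained `mm_relu`, `mm_activation` and `label_embedding` variables; the L2 regularizers and the truncated-normal initializer are left out.

```python
import numpy as np


def matching_matrix_np(model_input, w_relu, w_tanh, label_embedding):
    """Hypothetical NumPy re-creation of matching_matrix (no biases, no regularizers).

    model_input:     [batch, max_frames, num_features]
    w_relu:          [num_features, embedding_size]   (stand-in for "mm_relu" weights)
    w_tanh:          [embedding_size, embedding_size] (stand-in for "mm_activation" weights)
    label_embedding: [vocab_size, embedding_size]
    returns:         [batch, max_frames, vocab_size]
    """
    _, max_frames, num_features = model_input.shape
    vocab_size = label_embedding.shape[0]
    # Flatten so every frame is one row.
    flat = model_input.reshape(-1, num_features)
    # Two bias-free fully connected layers: ReLU, then tanh.
    frame_relu = np.maximum(flat @ w_relu, 0.0)
    frame_activation = np.tanh(frame_relu @ w_tanh)
    # "ik,jk->ij": dot product of every frame with every label embedding,
    # i.e. frame_activation @ label_embedding.T.
    mm_matrix = np.einsum("ik,jk->ij", frame_activation, label_embedding)
    # Restore the per-video frame axis.
    return mm_matrix.reshape(-1, max_frames, vocab_size)


rng = np.random.default_rng(0)
batch, max_frames, num_features, embedding_size, vocab_size = 2, 5, 8, 4, 3
x = rng.normal(size=(batch, max_frames, num_features))
w1 = rng.normal(size=(num_features, embedding_size))
w2 = rng.normal(size=(embedding_size, embedding_size))
labels = rng.normal(0.0, 0.5, size=(vocab_size, embedding_size))
out = matching_matrix_np(x, w1, w2, labels)
print(out.shape)  # (2, 5, 3)
```

Each entry `out[b, t, l]` is the score of frame `t` of video `b` against label `l`. Because the flatten and the final reshape use the same row order, frame `t` of video `b` is row `b * max_frames + t` of the intermediate `mm_matrix`.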