def matching_matrix(self,
                    model_input,
                    vocab_size,
                    l2_penalty=1e-8,
                    label_embedding_size=None,
                    **unused_params):
    """Score every frame against every label through a learned embedding.

    Each frame feature vector is projected by two bias-free fully connected
    layers (ReLU, then tanh) into an embedding space. A trainable
    ``label_embedding`` variable holds one vector per label. The output is the
    dot product between every projected frame and every label vector.

    Args:
      model_input: Tensor indexed as ``[batch, frames, features]`` (the code
        reads dims 1 and 2 and reshapes the result back to that layout). The
        feature dimension must be statically known; the frame dimension may
        be dynamic.
      vocab_size: Number of labels, i.e. rows of the label embedding.
      l2_penalty: L2 regularization strength applied to both projection
        layers and to the label embedding.
      label_embedding_size: Width of the shared embedding space. When
        ``None`` (the default) it is taken from ``FLAGS.mm_label_embedding``,
        matching the previous behavior.
      **unused_params: Accepted for interface compatibility and ignored.

    Returns:
      Tensor of shape ``[batch, frames, vocab_size]`` holding the frame/label
      matching scores.

    Raises:
      ValueError: If the feature dimension of ``model_input`` is not known
        statically.
    """
    input_shape = model_input.get_shape().as_list()
    num_features = input_shape[2]
    if num_features is None:
        # The fully connected layers need a static input width to size their
        # weight matrices, and reshape cannot take None as a dimension.
        raise ValueError("matching_matrix requires a statically known "
                         "feature dimension; got shape %s" % input_shape)

    max_frames = input_shape[1]
    if max_frames is None:
        # Variable-length inputs: use the runtime frame count so the scores
        # can still be reshaped back to [batch, frames, vocab_size].
        max_frames = tf.shape(model_input)[1]

    if label_embedding_size is None:
        label_embedding_size = FLAGS.mm_label_embedding

    # Flatten batch and frame axes so every frame is scored independently.
    flat_frames = tf.reshape(model_input, [-1, num_features])

    frame_relu = slim.fully_connected(
        flat_frames,
        label_embedding_size,
        activation_fn=tf.nn.relu,
        biases_initializer=None,
        weights_regularizer=slim.l2_regularizer(l2_penalty),
        scope="mm_relu")
    frame_activation = slim.fully_connected(
        frame_relu,
        label_embedding_size,
        activation_fn=tf.nn.tanh,
        biases_initializer=None,
        weights_regularizer=slim.l2_regularizer(l2_penalty),
        scope="mm_activation")

    # One embedding row per label: [vocab_size, label_embedding_size].
    label_embedding = tf.get_variable(
        "label_embedding",
        shape=[vocab_size, label_embedding_size],
        dtype=tf.float32,
        initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.5),
        regularizer=slim.l2_regularizer(l2_penalty),
        trainable=True)

    # [N, E] x [V, E]^T -> [N, V]; equivalent to einsum("ik,jk->ij", ...).
    mm_matrix = tf.matmul(frame_activation, label_embedding, transpose_b=True)

    # Restore the per-example frame axis.
    mm_output = tf.reshape(mm_matrix, [-1, max_frames, vocab_size])
    return mm_output
评论列表
文章目录