def create_model(self,
                 model_input,
                 vocab_size,
                 keep_prob,
                 num_mixtures=None,
                 l2_penalty=1e-8,
                 **unused_params):
    """Build a mixture-of-experts classification head over ``model_input``.

    For every label, a softmax gate with ``num_mixtures + 1`` options weighs
    ``num_mixtures`` sigmoid experts. The final gate option has no matching
    expert, so its probability mass contributes nothing to the prediction.

    Args:
      model_input: Input feature tensor. Presumably 2-D (batch, features);
        TODO confirm against the caller.
      vocab_size: Number of labels to predict.
      keep_prob: Keep probability forwarded to ``slim.dropout`` on the input.
      num_mixtures: Experts per label. Any falsy value (None, 0) falls back
        to ``FLAGS.moe_num_mixtures``.
      l2_penalty: Scale of the L2 regularizer on the gate and expert weights.
      **unused_params: Accepted and ignored.

    Returns:
      A dict with:
        "predictions": per-label probabilities reshaped to [-1, vocab_size].
        "zhaofeatures": the original ``model_input`` (not the dropped-out copy).
    """
    mixtures = num_mixtures
    if not mixtures:
        mixtures = FLAGS.moe_num_mixtures

    # NOTE(review): is_training is not passed to slim.dropout, so dropout is
    # presumably active whenever this graph runs unless the caller feeds
    # keep_prob=1.0 at evaluation time. Confirm against the caller.
    dropped = slim.dropout(model_input, keep_prob=keep_prob)

    # One stateless regularizer closure, shared by both dense layers.
    regularizer = slim.l2_regularizer(l2_penalty)

    # Gate logits: (num_mixtures + 1) per label, with no bias term.
    gate_logits = slim.fully_connected(
        dropped,
        vocab_size * (mixtures + 1),
        activation_fn=None,
        biases_initializer=None,
        weights_regularizer=regularizer,
        scope="gates")

    # Expert logits: num_mixtures per label (default bias initializer).
    expert_logits = slim.fully_connected(
        dropped,
        vocab_size * mixtures,
        activation_fn=None,
        weights_regularizer=regularizer,
        scope="experts")

    # One row per (example, label) pair.
    gates = tf.nn.softmax(tf.reshape(gate_logits, [-1, mixtures + 1]))
    experts = tf.nn.sigmoid(tf.reshape(expert_logits, [-1, mixtures]))

    # Drop the expert-less last gate column; mix the remaining experts.
    weighted = gates[:, :mixtures] * experts
    per_row = tf.reduce_sum(weighted, 1)

    predictions = tf.reshape(per_row, [-1, vocab_size])
    return {"predictions": predictions,
            "zhaofeatures": model_input}
评论列表
文章目录