def create_model(self, model_input, vocab_size, num_mixtures=None,
                 l2_penalty=1e-8, sub_scope="", original_input=None, **unused_params):
    """Build a chain of hidden ReLU layers with one supporting prediction each.

    For every layer in the chain:
      1. a fully connected ReLU layer is applied to the current input;
      2. `self.sub_model` turns that ReLU output into a supporting prediction;
      3. the L2-normalized ReLU output is concatenated with the original
         `model_input` and becomes the input of the next layer.

    After the last layer, `self.sub_model` is applied once more to the final
    chained input to produce the main prediction.

    Args:
      model_input: Input feature tensor. Assumed to be 2-D
        (batch x features), since features are joined along axis 1 --
        TODO confirm against callers.
      vocab_size: Forwarded unchanged to `self.sub_model`.
      num_mixtures: Accepted for interface compatibility; not used here.
      l2_penalty: Scale of the L2 weight regularizer on each ReLU layer.
      sub_scope: Prefix prepended to every variable scope name created here.
      original_input: Accepted for interface compatibility; not used here.
      **unused_params: Ignored.

    Returns:
      A dict with:
        "predictions": output of `self.sub_model` on the final chained input.
        "support_predictions": the per-layer supporting predictions
          concatenated along axis 1.

    Raises:
      ValueError: If the `hidden_chain_layers` flag is smaller than 1.
    """
    num_layers = FLAGS.hidden_chain_layers
    relu_cells = FLAGS.hidden_chain_relu_cells

    # Without at least one layer there are no supporting predictions, and
    # tf.concat cannot join an empty list; fail early with a clear message
    # instead of an opaque op-construction error after part of the graph
    # has already been built.
    if num_layers < 1:
        raise ValueError(
            "hidden_chain_layers must be at least 1, got %r" % (num_layers,))

    next_input = model_input
    support_predictions = []
    # `range` (rather than `xrange`) keeps this runnable on Python 2 and 3.
    for layer in range(num_layers):
        sub_relu = slim.fully_connected(
            next_input,
            relu_cells,
            activation_fn=tf.nn.relu,
            weights_regularizer=slim.l2_regularizer(l2_penalty),
            scope=sub_scope + "relu-%d" % layer)
        sub_prediction = self.sub_model(
            sub_relu, vocab_size,
            sub_scope=sub_scope + "prediction-%d" % layer)
        support_predictions.append(sub_prediction)

        # Every layer sees the raw input again, next to the normalized hidden
        # state of the previous layer.
        # NOTE(review): `dim` is the older spelling of `axis` for
        # l2_normalize; kept as-is for the TensorFlow version this file
        # targets -- confirm before renaming.
        relu_norm = tf.nn.l2_normalize(sub_relu, dim=1)
        next_input = tf.concat([model_input, relu_norm], axis=1)

    main_predictions = self.sub_model(
        next_input, vocab_size, sub_scope=sub_scope + "-main")
    support_predictions = tf.concat(support_predictions, axis=1)
    return {"predictions": main_predictions,
            "support_predictions": support_predictions}
评论列表
文章目录