def create_model(self, model_input, vocab_size, num_mixtures=None,
                 l2_penalty=1e-8, sub_scope="", original_input=None,
                 **unused_params):
    """Build a main predictor that is also fed by "support" predictions.

    The graph is assembled in this order:
      1. `self.sub_model` is applied to `model_input` with
         `FLAGS.num_supports` as its size argument, producing the support
         predictions.
      2. `model_input` is passed through a fully connected ReLU layer whose
         width equals the input's second dimension.
      3. The ReLU output and the support predictions are concatenated
         along axis 1.
      4. `self.sub_model` is applied to the concatenation with `vocab_size`
         as its size argument, producing the main predictions.

    Args:
      model_input: Input tensor; its static shape must have a known second
        dimension (presumably `(batch, features)` -- confirm with callers).
      vocab_size: Size argument handed to the main `self.sub_model` call.
      num_mixtures: Not used by this method.
      l2_penalty: Scale of the L2 regularizer on the ReLU layer's weights.
      sub_scope: Text combined into every scope name created here.
      original_input: Not used by this method.
      **unused_params: Accepted and ignored.

    Returns:
      A dict with the main sub-model output under "predictions" and the
      support sub-model output under "support_predictions".
    """
    support_count = FLAGS.num_supports
    feature_width = model_input.shape.as_list()[1]

    # Auxiliary head: built first so graph construction order is unchanged.
    support_out = self.sub_model(
        model_input, support_count, sub_scope=sub_scope + "-support")

    # Same-width nonlinear projection of the raw input.
    weight_decay = slim.l2_regularizer(l2_penalty)
    projected = slim.fully_connected(
        model_input,
        feature_width,
        activation_fn=tf.nn.relu,
        weights_regularizer=weight_decay,
        scope="main-relu-" + sub_scope)

    # The main head sees the projected features next to the support outputs.
    combined = tf.concat([projected, support_out], axis=1)
    main_out = self.sub_model(
        combined, vocab_size, sub_scope=sub_scope + "-main")

    return {
        "predictions": main_out,
        "support_predictions": support_out,
    }
评论列表
文章目录