multitask_divergence_deep_combine_chain_model.py source code

python

Project: youtube-8m  Author: wangheda
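def create_model(self, model_input, vocab_size, num_mixtures=None,
                   l2_penalty=1e-8, sub_scope="ddcc", original_input=None,
                   dropout=False, keep_prob=None, noise_level=None,
                   num_frames=None, **unused_params):
    # Assumes a module-level `import tensorflow as tf` and a `FLAGS` object
    # that defines `num_supports` and `divergence_model_count`.
    num_supports = FLAGS.num_supports  # read here but not used in this method
    num_models = FLAGS.divergence_model_count

    # Build one sub-model per index; each gets a distinct scope name
    # (sub_scope + index, e.g. "ddcc0", "ddcc1", ... with the default).
    support_predictions = []
    for i in range(num_models):
      sub_prediction = self.sub_chain_model(model_input, vocab_size, num_mixtures,
                                      l2_penalty, sub_scope+"%d"%i, original_input,
                                      dropout, keep_prob, noise_level)
      support_predictions.append(sub_prediction)
    # Stack the per-model outputs along a new axis 1, then average over that axis.
    support_predictions = tf.stack(support_predictions, axis=1)
    main_predictions = tf.reduce_mean(support_predictions, axis=1)
    return {"predictions": main_predictions, "support_predictions": support_predictions}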
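The stack-then-average step at the end of create_model can be illustrated without TensorFlow. The following is a minimal sketch in plain Python, assuming each sub-model returns a [batch, vocab] matrix of scores (the source does not show what sub_chain_model returns, so that shape is an assumption). The helper names stack_axis1, mean_axis1, and combine are hypothetical and exist only for this illustration; they mimic what tf.stack(..., axis=1) and tf.reduce_mean(..., axis=1) do on nested lists.

```python
def stack_axis1(predictions):
    """Turn M matrices of shape [batch, vocab] into one [batch, M, vocab] list."""
    # zip(*predictions) groups the i-th row of every model together.
    return [list(rows) for rows in zip(*predictions)]


def mean_axis1(stacked):
    """Average a [batch, M, vocab] list over the model axis -> [batch, vocab]."""
    # For each example, zip(*models) groups the scores of one class across models.
    return [[sum(col) / len(col) for col in zip(*models)] for models in stacked]


def combine(predictions):
    """Hypothetical stand-in for the last three lines of create_model."""
    support = stack_axis1(predictions)
    return {"predictions": mean_axis1(support), "support_predictions": support}


# Two illustrative sub-models, batch of 2 examples, 3 classes (made-up values).
model_a = [[0.25, 0.5, 1.0], [0.0, 0.5, 0.75]]
model_b = [[0.75, 0.5, 0.0], [0.5, 0.25, 0.25]]

result = combine([model_a, model_b])
print(result["predictions"])  # [[0.5, 0.5, 0.5], [0.25, 0.375, 0.5]]
```

Returning the stacked per-model outputs alongside the mean, as the original method does, keeps the individual predictions available to the caller in addition to the averaged ones.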