linear_regression_model.py source code


Project: youtube-8m, Author: wangheda
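```python
import tensorflow as tf
import tensorflow.contrib.slim as slim  # TensorFlow 1.x only


# Method of the model class in linear_regression_model.py.
def create_model(self, model_input, vocab_size, l2_penalty=1e-8, original_input=None, **unused_params):
    """Creates a linear ensemble: a softmax-weighted average of per-method predictions.

    Args:
      model_input: 'batch_size' x 'num_classes' x 'num_methods' tensor holding the
        class predictions of each base method.
      vocab_size: The number of classes in the dataset (not used directly here).
      l2_penalty: L2 regularization strength for the ensemble weights.
      original_input: Unused.

    Returns:
      A dictionary with a tensor containing the probability predictions of the
      model in the 'predictions' key. The dimensions of the tensor are
      batch_size x num_classes."""
    num_methods = model_input.get_shape().as_list()[-1]
    # One learnable logit per base method, with an L2 penalty on the logits.
    weight = tf.get_variable("ensemble_weight",
        shape=[num_methods],
        regularizer=slim.l2_regularizer(l2_penalty))
    # Softmax turns the logits into positive weights that sum to 1.
    weight = tf.nn.softmax(weight)
    # Weighted sum over the method axis k: (batch, classes, methods) -> (batch, classes).
    output = tf.einsum("ijk,k->ij", model_input, weight)
    return {"predictions": output}
```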
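The forward pass of `create_model` can be reproduced outside TensorFlow to check the shapes and arithmetic. This is a minimal NumPy sketch, assuming `model_input` holds per-method class probabilities laid out as (batch, num_classes, num_methods). The helper names `softmax` and `ensemble_predictions` are hypothetical and not part of the youtube-8m code, and the L2 regularization term is left out because it only affects training.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    z = np.exp(x - np.max(x))
    return z / z.sum()


def ensemble_predictions(model_input, weight_logits):
    """Hypothetical NumPy stand-in for the TensorFlow forward pass.

    model_input: array of shape (batch, num_classes, num_methods).
    weight_logits: array of shape (num_methods,), the unnormalized weights.
    Returns an array of shape (batch, num_classes).
    """
    weights = softmax(weight_logits)
    # Same contraction as tf.einsum("ijk,k->ij", ...): sum over the method axis k.
    return np.einsum("ijk,k->ij", model_input, weights)


# Two examples, three classes, two base methods (made-up illustrative values).
preds = np.array([
    [[0.9, 0.1], [0.2, 0.6], [0.5, 0.5]],
    [[0.0, 1.0], [1.0, 0.0], [0.3, 0.7]],
])

# Equal logits give weights [0.5, 0.5], so this is the per-class mean over methods.
out = ensemble_predictions(preds, np.zeros(2))
print(out)
```

Because the softmax weights are positive and sum to 1, each output entry is a convex combination of the methods' predictions and stays within their range. With equal logits the ensemble reduces to a plain average, and a large logit for one method pushes the output toward that method's predictions.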