embeddings_ops.py source code

python

Project: lsdc  Author: febert
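```python
# Imports are an assumption (TensorFlow 1.x). If embedding_lookup is not
# defined elsewhere in this file, TensorFlow's own version can be imported.
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.embedding_ops import embedding_lookup


def categorical_variable(tensor_in, n_classes, embedding_size, name):
  """Creates an embedding for a categorical variable with a given number of classes.

  Args:
    tensor_in: Input tensor of integer class identifiers (can be a batch or
      N-dimensional).
    n_classes: Number of classes.
    embedding_size: Size of the embedding vector representing each class.
    name: Name of this categorical variable.
  Returns:
    Tensor with the shape of `tensor_in` plus a trailing dimension of size
    `embedding_size`.

  Example:
    Calling categorical_variable([1, 2], 5, 10, "my_cat") returns a 2 x 10
    tensor, where each row is the embedding of the corresponding class.
  """
  with vs.variable_scope(name):
    # Trainable lookup table: one row of length embedding_size per class.
    embeddings = vs.get_variable(name + "_embeddings",
                                 [n_classes, embedding_size])
    # Gather the row for each class id in tensor_in.
    return embedding_lookup(embeddings, tensor_in)
```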
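The lookup semantics described in the docstring (the output keeps the shape of the class-id input and gains a trailing dimension of size `embedding_size`) can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, assuming a randomly initialized table; `make_embedding_table`, `lookup`, and `shape` are hypothetical helpers, not part of the lsdc project or the TensorFlow API.

```python
import random
from typing import List


def make_embedding_table(n_classes: int, embedding_size: int,
                         seed: int = 0) -> List[List[float]]:
    """Hypothetical stand-in for the [n_classes, embedding_size] variable.

    Values are drawn uniformly from [-1, 1] here; the initializer that
    vs.get_variable uses in the original code may differ.
    """
    rng = random.Random(seed)
    return [[rng.uniform(-1.0, 1.0) for _ in range(embedding_size)]
            for _ in range(n_classes)]


def lookup(table: List[List[float]], ids):
    """Replace every integer id in a (possibly nested) list with its table row.

    The result keeps the nesting (shape) of `ids` and adds one innermost
    dimension of length embedding_size.
    """
    if isinstance(ids, int):
        if not 0 <= ids < len(table):
            raise IndexError(f"class id {ids} out of range [0, {len(table)})")
        return list(table[ids])  # copy so callers cannot mutate the table
    return [lookup(table, sub) for sub in ids]


def shape(x) -> tuple:
    """Shape of a regular nested list (assumes siblings have equal length)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


table = make_embedding_table(n_classes=5, embedding_size=10)
batch = lookup(table, [1, 2])                    # the docstring's example ids
print(shape(batch))                              # (2, 10)
grid = lookup(table, [[0, 1], [2, 3], [4, 0]])   # 2-D ids
print(shape(grid))                               # (3, 2, 10)
print(grid[2][1] == table[0])                    # True: same class, same vector
```

Because every occurrence of a class id maps to the same row, the table has only `n_classes * embedding_size` parameters no matter how large or deeply nested the input is; in the TensorFlow version those rows are trained along with the rest of the model.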