def categorical_variable(tensor_in, n_classes, embedding_size, name):
    """Embed a categorical input using a per-variable embedding table.

    Inside a variable scope called ``name``, a table variable of shape
    ``[n_classes, embedding_size]`` named ``name + "_embeddings"`` is obtained
    via ``vs.get_variable``. The class ids in ``tensor_in`` are then looked up
    in that table.

    Args:
        tensor_in: Tensor of class ids; may be a batch or N-dimensional.
        n_classes: Number of distinct classes (the number of table rows).
        embedding_size: Length of the vector representing each class.
        name: Name of this categorical variable. It is used both as the
            variable scope and as the prefix of the table's variable name.

    Returns:
        Tensor with the input's shape plus an extra dimension holding the
        embedding.

    Example:
        ``categorical_variable([1, 2], 5, 10, "my_cat")`` returns a 2 x 10
        tensor in which each row is the representation of the class.
    """
    table_shape = [n_classes, embedding_size]
    with vs.variable_scope(name):
        table_name = name + "_embeddings"
        lookup_table = vs.get_variable(table_name, table_shape)
        # Keep the lookup inside the scope, presumably so its ops are created
        # under the same scope as the table variable.
        return embedding_lookup(lookup_table, tensor_in)
评论列表
文章目录