def embed_inputs(raw_inputs,
                 vocab_size,
                 embed_size,
                 reserve_zero=True,
                 name='embed',
                 embeddings=None):
    """Look up embedding vectors for integer ids.

    Args:
        raw_inputs: Ids to look up; passed straight to
            `tf.nn.embedding_lookup`.
        vocab_size: Number of rows in the embedding variable created when
            `embeddings` is None. Unused otherwise.
        embed_size: Width of the created embedding variable and of the zero
            row prepended for `reserve_zero`. Unused when `embeddings` is
            given.
        reserve_zero: Only applies when the table is created here. If True,
            a constant all-zeros row is prepended to the variable, so id 0
            looks up a zero vector and id i (i >= 1) looks up row i - 1 of
            the variable; the lookup table then has vocab_size + 1 rows.
        name: Name of the `tf.variable_scope` wrapping the variable creation
            and the lookup.
        embeddings: Optional caller-supplied table. When given it is used
            as-is for the lookup.

    Returns:
        The tensor produced by `tf.nn.embedding_lookup` on the table.
    """
    with tf.variable_scope(name):
        if embeddings is None:
            shape = (vocab_size, embed_size)
            # Orthogonally initialised float32 variable named 'embeddings'
            # inside the `name` scope.
            embeddings = tf.get_variable(
                'embeddings',
                shape=shape,
                initializer=tf.orthogonal_initializer(),
                dtype=tf.float32)
            # NOTE(review): confirm whether a caller-supplied `embeddings`
            # table should also receive the zero row; here it does not.
            if reserve_zero:
                # Prepend a constant zero row instead of zeroing row 0 of the
                # variable: the constant receives no gradient updates, so id 0
                # stays exactly zero throughout training.
                zeros = tf.zeros((1, embed_size), dtype=tf.float32)
                embeddings = tf.concat([zeros, embeddings], axis=0)
        inputs = tf.nn.embedding_lookup(embeddings, raw_inputs)
    return inputs
评论列表
文章目录