layers.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:tag_srl 作者: danfriedman0 项目源码 文件源码
def embed_inputs(raw_inputs,
                 vocab_size,
                 embed_size,
                 reserve_zero=True,
                 name='embed',
                 embeddings=None):
    with tf.variable_scope(name):
        if embeddings is None:
            shape = (vocab_size, embed_size)
            embeddings = tf.get_variable(
                'embeddings',
                shape=(vocab_size, embed_size),
                initializer=tf.orthogonal_initializer(),
                dtype=tf.float32)

            # If reserve_zero, make sure first row is always zeros
            if reserve_zero:
                zeros = tf.zeros((1, embed_size), dtype=tf.float32)
                embeddings = tf.concat([zeros, embeddings], axis=0)

        inputs = tf.nn.embedding_lookup(embeddings, raw_inputs)
        return inputs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号