gen_embeddings.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:almond-nnparser 作者: Stanford-Mobisocial-IoT-Lab 项目源码 文件源码
def make_skipgram_softmax_loss(embeddings_matrix, vocabulary_size, vector_size):
    vectors = tf.get_variable('vectors', (vocabulary_size, vector_size), dtype=tf.float32, initializer=tf.constant_initializer(embeddings_matrix))
    minibatch = tf.placeholder(shape=(None, 2), dtype=tf.int32)

    center_word_vector = tf.nn.embedding_lookup(vectors, minibatch[:,0])
    yhat = tf.matmul(center_word_vector, vectors, transpose_b=True)

    predict_word = minibatch[:,1]
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=predict_word, logits=yhat)
    loss = tf.reduce_mean(loss)
    return vectors, minibatch, loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号