keras_item2vec.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:MovieTaster-Open 作者: lujiaying 项目源码 文件源码
def skipgram_model(vocab_size, embedding_dim=100, paradigm='Functional'):
    """Build and compile a (target, context) pair-scoring Keras model.

    Two construction styles are supported, selected by ``paradigm``:

    * ``'Sequential'`` -- two independent ``Sequential`` sub-models, each
      holding its own ``Embedding`` layer, are combined with a ``Merge``
      layer in ``'dot'`` mode, reshaped to ``(1,)`` and passed through a
      sigmoid ``Activation``.
    * ``'Functional'`` -- two inputs named ``'target'`` and ``'context'``
      are looked up in a single ``Embedding`` layer named
      ``'shared_embedding'``; the two embeddings are combined with ``dot``
      along the last axis, reshaped to ``(1,)`` and fed to a one-unit
      sigmoid ``Dense`` layer.

    Either way the model is compiled with the ``'adam'`` optimizer and
    ``'binary_crossentropy'`` loss before being returned.

    Args:
        vocab_size: First argument of every ``Embedding`` layer created.
        embedding_dim: Second argument of every ``Embedding`` layer created.
        paradigm: ``'Sequential'`` or ``'Functional'``.

    Returns:
        The compiled model, or ``None`` (after printing ``'paradigm error'``)
        when ``paradigm`` is neither of the two supported values.
    """
    # Reject unsupported paradigms up front so the builders below stay flat.
    if paradigm not in ('Sequential', 'Functional'):
        print('paradigm error')
        return None

    if paradigm == 'Sequential':
        # NOTE(review): `Merge` appears to exist only in older Keras
        # releases -- confirm against the installed version.
        target_branch = Sequential()
        target_branch.add(Embedding(vocab_size, embedding_dim, input_length=1))
        context_branch = Sequential()
        context_branch.add(Embedding(vocab_size, embedding_dim, input_length=1))

        # Join the two embedding branches with a dot product, then flatten
        # the result to a single value per sample and apply a sigmoid.
        scorer = Sequential()
        scorer.add(Merge([target_branch, context_branch], mode='dot'))
        scorer.add(Reshape((1,), input_shape=(1, 1)))
        scorer.add(Activation('sigmoid'))
        scorer.compile(optimizer='adam', loss='binary_crossentropy')
        return scorer

    # Functional paradigm: both inputs go through the same embedding layer.
    target_in = Input(shape=(1,), name='target')
    context_in = Input(shape=(1,), name='context')
    embed = Embedding(
        vocab_size,
        embedding_dim,
        input_length=1,
        name='shared_embedding',
    )
    target_vec = embed(target_in)
    context_vec = embed(context_in)

    similarity = dot([target_vec, context_vec], axes=-1)
    flat_similarity = Reshape((1,), input_shape=(1, 1))(similarity)
    output = Dense(1, input_shape=(1,), activation='sigmoid')(flat_similarity)

    scorer = Model(inputs=[target_in, context_in], outputs=output)
    scorer.compile(optimizer='adam', loss='binary_crossentropy')
    return scorer
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号