def skipgram_model(vocab_size, embedding_dim=100, paradigm='Functional'):
    """Build and compile a skip-gram style binary classifier over index pairs.

    The model takes two integer indices (a ``target`` and a ``context``
    candidate). It looks each one up in an embedding table and scores the pair
    by the dot product of the two embedding vectors. A sigmoid then maps the
    score into (0, 1).

    Args:
        vocab_size: Number of rows in each embedding table (the ``input_dim``
            passed to ``Embedding``).
        embedding_dim: Width of each embedding vector.
        paradigm: Which architecture variant to build:

            * ``'Functional'`` -- one embedding table shared by target and
              context. The dot-product score goes through
              ``Dense(1, activation='sigmoid')``, i.e. a learned scale and
              bias before the sigmoid.
            * ``'Sequential'`` -- two independent embedding tables, one for
              target and one for context. The dot-product score goes straight
              through a sigmoid ``Activation``.

    Returns:
        A Keras model compiled with ``optimizer='adam'`` and
        ``loss='binary_crossentropy'``. It takes the inputs
        ``[target, context]``, each of shape ``(batch, 1)``, and produces an
        output of shape ``(batch, 1)``.

        If ``paradigm`` is neither ``'Sequential'`` nor ``'Functional'``, it
        prints ``'paradigm error'`` and returns ``None`` instead.
    """
    if paradigm == 'Sequential':
        # Two independent (unshared) embedding tables, one per input.
        # This variant used to join two Sequential branches with the legacy
        # ``Merge`` layer. That layer was deprecated in Keras 2 and removed in
        # Keras 2.2; the functional ``dot`` merge is its supported replacement
        # and yields the same architecture.
        target = Input(shape=(1,))
        context = Input(shape=(1,))
        target_table = Embedding(vocab_size, embedding_dim, input_length=1)
        context_table = Embedding(vocab_size, embedding_dim, input_length=1)
        embedding_target = target_table(target)      # (batch, 1, embedding_dim)
        embedding_context = context_table(context)   # (batch, 1, embedding_dim)
        # Contract the embedding axis: (batch, 1, dim) x (batch, 1, dim)
        # -> (batch, 1, 1).
        merged_vector = dot([embedding_target, embedding_context], axes=-1)
        reshaped_vector = Reshape((1,))(merged_vector)  # (batch, 1)
        prediction = Activation('sigmoid')(reshaped_vector)
    elif paradigm == 'Functional':
        target = Input(shape=(1,), name='target')
        context = Input(shape=(1,), name='context')
        # A single layer instance called on both inputs, so target and
        # context share one embedding table.
        shared_embedding = Embedding(vocab_size, embedding_dim, input_length=1,
                                     name='shared_embedding')
        embedding_target = shared_embedding(target)    # (batch, 1, embedding_dim)
        embedding_context = shared_embedding(context)  # (batch, 1, embedding_dim)
        # Contract the embedding axis -> (batch, 1, 1).
        merged_vector = dot([embedding_target, embedding_context], axes=-1)
        reshaped_vector = Reshape((1,))(merged_vector)  # (batch, 1)
        # Dense(1) adds a trainable weight and bias on the scalar score
        # before the sigmoid.
        prediction = Dense(1, activation='sigmoid')(reshaped_vector)
    else:
        # NOTE(review): an unknown paradigm is reported on stdout and
        # signalled by returning None. Raising ValueError would fail faster,
        # but existing callers may rely on the None return.
        print('paradigm error')
        return None

    model = Model(inputs=[target, context], outputs=prediction)
    model.compile(optimizer='adam', loss='binary_crossentropy')
    return model
评论列表
文章目录