def create(cls, classes, maximum_tokens, embedding_size, lstm_units, dropout, bidirectional):
    """
    Create a model that labels semantic relationships between text pairs.

    The text pairs are passed in as two aligned matrices of size
    (batch size, maximum embedding tokens, embedding size). They are generated by TextPairEmbeddingGenerator.

    Architecture:
      * The same LSTM layer (optionally bidirectional) encodes both inputs, so its weights are shared.
      * The two encodings r1 and r2 are concatenated with their elementwise product r1 * r2 and
        their elementwise squared difference (r1 - r2) ** 2.
      * Optional dropout is applied to that concatenated vector.
      * A ReLU hidden layer of floor(sqrt(m)) units follows, where m is the length of the
        concatenated vector.
      * A softmax output layer with one unit per class comes last.

    The model is compiled with the Adam optimizer, sparse categorical cross-entropy loss
    (i.e. integer class labels) and an accuracy metric.

    :param classes: the number of distinct classes to categorize
    :type classes: int
    :param maximum_tokens: maximum number of embedded tokens
    :type maximum_tokens: int
    :param embedding_size: size of the embedding vector
    :type embedding_size: int
    :param lstm_units: number of hidden units in the shared LSTM
    :type lstm_units: int
    :param dropout: dropout rate or None for no dropout
    :type dropout: float or None
    :param bidirectional: should the shared LSTM be bidirectional?
    :type bidirectional: bool
    :return: the created model
    :rtype: TextPairClassifier
    """
    # Create the model geometry.
    input_shape = (maximum_tokens, embedding_size)
    # Input two sets of aligned text pairs.
    input_1 = Input(input_shape)
    input_2 = Input(input_shape)
    # Apply the same LSTM layer instance to each input so both texts share one set of weights.
    if bidirectional:
        lstm = Bidirectional(LSTM(lstm_units), name="lstm")
    else:
        lstm = LSTM(lstm_units, name="lstm")
    r1 = lstm(input_1)
    r2 = lstm(input_2)
    # Width of each encoding. Bidirectional is used with its default merge_mode ("concat"),
    # which concatenates the forward and backward outputs and so doubles the width.
    encoding_width = 2 * lstm_units if bidirectional else lstm_units
    # Concatenate the encodings with their elementwise product and squared difference.
    p = multiply([r1, r2])
    # r1 - r2 is built as r1 + (-r2) using only the add layer.
    negative_r2 = Lambda(lambda x: -x)(r2)
    d = add([r1, negative_r2])
    q = multiply([d, d])
    v = [r1, r2, p, q]
    lstm_output = concatenate(v)
    if dropout is not None:
        lstm_output = Dropout(dropout, name="dropout")(lstm_output)
    # A hidden ReLU layer maps the concatenated vector toward the labels. Its number of units is the
    # floor of the square root of the concatenated vector's length. Every tensor in v has width
    # encoding_width, so the length is computed directly. This avoids the TF1-only
    # `shape[i].value` API, which fails on TF2/Keras 3 where `shape[i]` is a plain int.
    m = len(v) * encoding_width
    perceptron = Dense(math.floor(math.sqrt(m)), activation="relu")(lstm_output)
    # Softmax output layer: one probability per class.
    logistic_regression = Dense(classes, activation="softmax", name="softmax")(perceptron)
    model = Model([input_1, input_2], logistic_regression, "Text pair classifier")
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return cls(model)
评论列表
文章目录