import time

import numpy as np

from keras import backend as K
from keras.layers import (Input, Embedding, Dropout, Reshape, TimeDistributed,
                          LSTM, GlobalAveragePooling1D, Dense)
from keras.models import Model

# NOTE: Attention (a soft-attention pooling layer) and get_logger are
# project-specific helpers; the import paths below are assumptions and may
# need adjusting to your codebase.
from softattention import Attention
from utils import get_logger


def build_attention_model(opts, vocab_size=0, maxnum=50, maxlen=50, embedd_dim=50,
                          embedding_weights=None, verbose=False, init_mean_value=None):
    # N = number of sentences per text, L = number of words per sentence.
    N = maxnum
    L = maxlen
    logger = get_logger('Build attention pooling model')
    logger.info("Model parameters: max_sentnum = %d, max_sentlen = %d, embedding dim = %s, "
                "lstm_units = %s, drop rate = %s, l2 = %s"
                % (N, L, embedd_dim, opts.lstm_units, opts.dropout, opts.l2_value))
    # Each text arrives as a flat sequence of N * L word indices.
    word_input = Input(shape=(N * L,), dtype='int32', name='word_input')
    x = Embedding(output_dim=embedd_dim, input_dim=vocab_size, input_length=N * L,
                  weights=embedding_weights, name='x')(word_input)
    drop_x = Dropout(opts.dropout, name='drop_x')(x)
    # Reshape the flat word sequence into N sentences of L word embeddings each.
    resh_W = Reshape((N, L, embedd_dim), name='resh_W')(drop_x)
    # Word-level LSTM applied to every sentence independently.
    z = TimeDistributed(LSTM(opts.lstm_units, return_sequences=True), name='z')(resh_W)
    # Average the word-level hidden states to get one vector per sentence.
    avg_z = TimeDistributed(GlobalAveragePooling1D(), name='avg_z')(z)
    # Sentence-level LSTM over the sequence of sentence vectors.
    hz = LSTM(opts.lstm_units, return_sequences=True, name='hz')(avg_z)
    # Alternative pooling strategies kept for reference:
    # avg_h = MeanOverTime(mask_zero=True, name='avg_h')(hz)
    # avg_hz = GlobalAveragePooling1D(name='avg_hz')(hz)
    # Attention pooling over the sentence-level hidden states.
    attent_hz = Attention(name='attent_hz')(hz)
    y = Dense(1, activation='sigmoid', name='output')(attent_hz)
    model = Model(input=word_input, output=y)  # Keras 1.x API ('inputs'/'outputs' in Keras 2)
    if opts.init_bias and init_mean_value is not None:
        # Start the sigmoid output near the mean training score by setting the
        # output bias to the logit of that mean.
        logger.info("Initialise output layer bias with log(y_mean / (1 - y_mean))")
        bias_value = (np.log(init_mean_value) - np.log(1 - init_mean_value)).astype(K.floatx())
        # K.set_value works for both Theano and TensorFlow backends;
        # '.b' is the Keras 1.x bias attribute ('.bias' in Keras 2).
        K.set_value(model.layers[-1].b, bias_value)
    if verbose:
        model.summary()

    start_time = time.time()
    model.compile(loss='mse', optimizer='rmsprop')
    total_time = time.time() - start_time
    logger.info("Model compiled in %.4f s" % total_time)

    return model
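

# Minimal usage sketch (illustrative, not from the original source): 'opts' can be
# any object exposing the fields the builder reads (lstm_units, dropout, l2_value,
# init_bias); the vocabulary size, shapes, and random batch below are assumptions.
if __name__ == '__main__':
    from argparse import Namespace

    opts = Namespace(lstm_units=100, dropout=0.5, l2_value=0.0, init_bias=True)
    model = build_attention_model(opts, vocab_size=4000, maxnum=50, maxlen=50,
                                  embedd_dim=50, embedding_weights=None,
                                  verbose=True, init_mean_value=np.array([0.5]))

    # Dummy batch: 2 texts, each flattened to maxnum * maxlen word indices.
    X = np.random.randint(0, 4000, size=(2, 50 * 50)).astype('int32')
    preds = model.predict(X)  # shape (2, 1); sigmoid scores in (0, 1)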